// tinyquant_core/codec/codebook.rs
//! Immutable quantization lookup table with Python-parity training.
//!
//! A [`Codebook`] owns `2^bit_width` sorted, strictly-ascending `f32`
//! entries — one representative value per quantized index. It mirrors
//! `tinyquant_cpu.codec.codebook.Codebook` field-for-field and produces
//! byte-identical output on the full training + quantize + dequantize
//! pipeline for the same inputs.
//!
//! The algorithmic contract is fixed by
//! `docs/design/rust/numerical-semantics.md` §Quantization and
//! §Codebook training. Changing anything here requires regenerating
//! the frozen fixtures under `tests/fixtures/codebook/` and
//! `tests/fixtures/quantize/` via `cargo xtask fixtures refresh-all`.
//!
//! Under `feature = "simd"` the quantize / dequantize methods route
//! through [`crate::codec::simd_api`], which is the single source of
//! truth for dispatch selection. No `unsafe` intrinsics are invoked
//! directly from this file, but the file-level `#![allow(unsafe_code)]`
//! is retained so that future SIMD specialisation (Phase 22) can add
//! direct intrinsic paths here without re-plumbing the lint surface.
#![allow(unsafe_code)]

use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::cmp::Ordering;
use core::fmt;

use crate::codec::codec_config::CodecConfig;
#[cfg(not(feature = "simd"))]
use crate::codec::kernels::scalar as scalar_kernel;
use crate::errors::CodecError;
/// Sorted, immutable mapping from quantized `u8` indices to `f32`
/// representative values.
///
/// Every constructor enforces the same three invariants:
///
/// 1. The entry count equals `1 << bit_width` (i.e.
///    `CodecConfig::num_codebook_entries`).
/// 2. Entries ascend strictly under `f32::total_cmp`.
/// 3. No two adjacent entries compare equal, so all entries are distinct.
///
/// Entries live behind an `Arc<[f32]>`, which makes `Clone` a cheap
/// reference-count bump. `PartialEq` compares the numerical contents,
/// not the allocation identity.
#[derive(Clone)]
pub struct Codebook {
    /// `2^bit_width` strictly-ascending representative values.
    entries: Arc<[f32]>,
    /// Bit width the table was built for.
    bit_width: u8,
}
51impl Codebook {
52    /// Build a codebook from a caller-owned `Box<[f32]>`.
53    ///
54    /// # Errors
55    ///
56    /// * [`CodecError::CodebookEntryCount`] — `entries.len()` does not
57    ///   equal `2^bit_width`.
58    /// * [`CodecError::CodebookNotSorted`] — entries are not in
59    ///   strictly ascending order (excluding duplicates).
60    /// * [`CodecError::CodebookDuplicate`] — two adjacent entries
61    ///   compare equal.
62    pub fn new(entries: Box<[f32]>, bit_width: u8) -> Result<Self, CodecError> {
63        let expected = 1u32
64            .checked_shl(u32::from(bit_width))
65            .ok_or(CodecError::UnsupportedBitWidth { got: bit_width })?;
66        let got = u32::try_from(entries.len()).map_err(|_| CodecError::CodebookEntryCount {
67            expected,
68            got: u32::MAX,
69            bit_width,
70        })?;
71        if got != expected {
72            return Err(CodecError::CodebookEntryCount {
73                expected,
74                got,
75                bit_width,
76            });
77        }
78
79        // Validate monotonicity and distinctness in a single pass without
80        // indexing, to satisfy `clippy::indexing_slicing`.
81        let mut distinct: u32 = u32::from(!entries.is_empty());
82        let mut prev: Option<f32> = None;
83        for &value in &*entries {
84            if let Some(p) = prev {
85                match f32::total_cmp(&p, &value) {
86                    Ordering::Less => distinct += 1,
87                    Ordering::Equal => {
88                        // Count distinct entries; `expected - distinct`
89                        // is the minimum number of duplicates we must
90                        // still resolve to make the codebook well-formed.
91                        return Err(CodecError::CodebookDuplicate {
92                            expected,
93                            got: distinct,
94                        });
95                    }
96                    Ordering::Greater => return Err(CodecError::CodebookNotSorted),
97                }
98            }
99            prev = Some(value);
100        }
101
102        Ok(Self {
103            entries: Arc::from(entries),
104            bit_width,
105        })
106    }
107
108    /// Train a codebook by uniform-quantile estimation over a flattened
109    /// f32 sample buffer.
110    ///
111    /// Mirrors Python's
112    /// `np.quantile(flat.astype(np.float64),
113    /// np.linspace(0, 1, num_entries)).astype(np.float32)` exactly:
114    ///
115    /// 1. Promote every sample to `f64`.
116    /// 2. Sort with `f64::total_cmp`.
117    /// 3. For each `k` in `0..num_entries`, compute the linearly-
118    ///    interpolated quantile value in `f64`.
119    /// 4. Cast to `f32` (round-to-nearest-even) and enforce distinctness.
120    ///
121    /// `config.bit_width` determines the number of entries; `config.seed`
122    /// and `config.dimension` are not consulted by this function.
123    ///
124    /// # Errors
125    ///
126    /// * [`CodecError::InsufficientTrainingData`] — `vectors` is empty
127    ///   or produces fewer than `num_entries` distinct quantile
128    ///   representatives.
129    /// * Any error from [`Codebook::new`] on the freshly-built entries.
130    #[allow(
131        clippy::cast_precision_loss,
132        clippy::cast_possible_truncation,
133        clippy::cast_sign_loss
134    )]
135    pub fn train(vectors: &[f32], config: &CodecConfig) -> Result<Self, CodecError> {
136        let num_entries_u32 = config.num_codebook_entries();
137        let num_entries = num_entries_u32 as usize;
138        if vectors.is_empty() {
139            return Err(CodecError::InsufficientTrainingData {
140                expected: num_entries_u32,
141            });
142        }
143
144        // Step 1: promote to f64.
145        let mut flat: Vec<f64> = vectors.iter().copied().map(f64::from).collect();
146
147        // Step 2: sort with total ordering (matches NumPy's sort on non-
148        // NaN data; also makes NaNs deterministic if they ever sneak in).
149        flat.sort_by(f64::total_cmp);
150
151        // Step 3: quantile interpolation.
152        let len = flat.len();
153        let last_idx = len.saturating_sub(1);
154        let num_entries_minus_one =
155            num_entries
156                .checked_sub(1)
157                .ok_or(CodecError::InsufficientTrainingData {
158                    expected: num_entries_u32,
159                })?;
160        let divisor = num_entries_minus_one as f64;
161        let span = last_idx as f64;
162
163        let mut entries_f32: Vec<f32> = Vec::with_capacity(num_entries);
164        for k in 0..num_entries {
165            let q = (k as f64) / divisor;
166            let h = q * span;
167            let floor_h = libm::floor(h);
168            let frac = h - floor_h;
169            let i = floor_h as usize;
170            let i_plus_one = i.saturating_add(1).min(last_idx);
171            let lo = *flat.get(i).ok_or(CodecError::InsufficientTrainingData {
172                expected: num_entries_u32,
173            })?;
174            let hi = *flat
175                .get(i_plus_one)
176                .ok_or(CodecError::InsufficientTrainingData {
177                    expected: num_entries_u32,
178                })?;
179            let value_f64 = lo + frac * (hi - lo);
180            entries_f32.push(value_f64 as f32);
181        }
182
183        // Step 4: defensive sort (NumPy's linspace-driven quantiles are
184        // already monotone, but explicit sorting guards against any
185        // future change to the interpolation recipe).
186        entries_f32.sort_by(f32::total_cmp);
187
188        // Distinctness check — surface a dedicated error before handing
189        // off to `Codebook::new`, because the insufficient-training-data
190        // story is more informative than a raw duplicate-entries error
191        // in this context.
192        let mut distinct: u32 = 1;
193        let mut iter = entries_f32.iter();
194        if let Some(first) = iter.next() {
195            let mut prev = *first;
196            for &value in iter {
197                if f32::total_cmp(&prev, &value) == Ordering::Less {
198                    distinct += 1;
199                }
200                prev = value;
201            }
202        }
203        if distinct < num_entries_u32 {
204            return Err(CodecError::InsufficientTrainingData {
205                expected: num_entries_u32,
206            });
207        }
208
209        // Delegating to `Codebook::new` re-runs the invariant checks so
210        // the constructor stays the single source of truth.
211        Self::new(entries_f32.into_boxed_slice(), config.bit_width())
212    }
213
214    /// Number of entries (`2^bit_width`).
215    #[inline]
216    pub fn num_entries(&self) -> u32 {
217        // Constructor guarantees the length fits in `u32`.
218        u32::try_from(self.entries.len()).unwrap_or(u32::MAX)
219    }
220
221    /// The bit width this codebook was built for.
222    #[inline]
223    pub const fn bit_width(&self) -> u8 {
224        self.bit_width
225    }
226
227    /// Borrow the underlying sorted entries.
228    #[inline]
229    pub fn entries(&self) -> &[f32] {
230        &self.entries
231    }
232
233    /// Quantize `values` into `indices` by finding the nearest entry for
234    /// each value. Ties favor the right (higher-valued) neighbor, matching
235    /// Python's strict `<` tie-break.
236    ///
237    /// Under `feature = "simd"` this delegates to
238    /// [`crate::codec::simd_api::quantize_into`], which is the single
239    /// source of truth for dispatch selection. Without the feature,
240    /// it calls the scalar reference kernel directly.
241    ///
242    /// # Errors
243    ///
244    /// * [`CodecError::LengthMismatch`] — `values.len() != indices.len()`.
245    pub fn quantize_into(&self, values: &[f32], indices: &mut [u8]) -> Result<(), CodecError> {
246        let entries = &self.entries;
247        #[cfg(feature = "simd")]
248        {
249            crate::codec::simd_api::quantize_into(entries, values, indices)
250        }
251        #[cfg(not(feature = "simd"))]
252        {
253            scalar_kernel::quantize_into(entries, values, indices)
254        }
255    }
256
257    /// Dequantize `indices` into `values` by gathering the corresponding
258    /// codebook entries.
259    ///
260    /// Under `feature = "simd"` this delegates to
261    /// [`crate::codec::simd_api::dequantize_into`].
262    ///
263    /// # Errors
264    ///
265    /// * [`CodecError::LengthMismatch`] — `indices.len() != values.len()`.
266    /// * [`CodecError::IndexOutOfRange`] — any index is
267    ///   `>= num_entries()`.
268    pub fn dequantize_into(&self, indices: &[u8], values: &mut [f32]) -> Result<(), CodecError> {
269        let entries = &self.entries;
270        #[cfg(feature = "simd")]
271        {
272            crate::codec::simd_api::dequantize_into(entries, indices, values)
273        }
274        #[cfg(not(feature = "simd"))]
275        {
276            scalar_kernel::dequantize_into(entries, indices, values)
277        }
278    }
279
280    /// Convenience: allocate and return the quantized indices.
281    ///
282    /// # Errors
283    ///
284    /// See [`Codebook::quantize_into`].
285    pub fn quantize(&self, values: &[f32]) -> Result<Vec<u8>, CodecError> {
286        let mut out = vec![0u8; values.len()];
287        self.quantize_into(values, &mut out)?;
288        Ok(out)
289    }
290
291    /// Convenience: allocate and return the dequantized values.
292    ///
293    /// # Errors
294    ///
295    /// See [`Codebook::dequantize_into`].
296    pub fn dequantize(&self, indices: &[u8]) -> Result<Vec<f32>, CodecError> {
297        let mut out = vec![0.0f32; indices.len()];
298        self.dequantize_into(indices, &mut out)?;
299        Ok(out)
300    }
301}
302
303impl fmt::Debug for Codebook {
304    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305        f.debug_struct("Codebook")
306            .field("bit_width", &self.bit_width)
307            .field("num_entries", &self.num_entries())
308            .field("entries", &self.entries)
309            .finish()
310    }
311}
312
313impl PartialEq for Codebook {
314    fn eq(&self, other: &Self) -> bool {
315        if self.bit_width != other.bit_width {
316            return false;
317        }
318        if self.entries.len() != other.entries.len() {
319            return false;
320        }
321        self.entries
322            .iter()
323            .zip(other.entries.iter())
324            .all(|(a, b)| a.to_bits() == b.to_bits())
325    }
326}