tinyquant_core/codec/codebook.rs
1//! Immutable quantization lookup table with Python-parity training.
2//!
3//! A [`Codebook`] owns `2^bit_width` sorted, strictly-ascending `f32`
4//! entries — one representative value per quantized index. It mirrors
5//! `tinyquant_cpu.codec.codebook.Codebook` field-for-field and produces
6//! byte-identical output on the full training + quantize + dequantize
7//! pipeline for the same inputs.
8//!
9//! The algorithmic contract is fixed by
10//! `docs/design/rust/numerical-semantics.md` §Quantization and
11//! §Codebook training. Changing anything here requires regenerating
12//! the frozen fixtures under `tests/fixtures/codebook/` and
13//! `tests/fixtures/quantize/` via `cargo xtask fixtures refresh-all`.
14//!
15//! Under `feature = "simd"` the quantize / dequantize methods route
16//! through [`crate::codec::simd_api`], which is the single source of
17//! truth for dispatch selection. No `unsafe` intrinsics are invoked
18//! directly from this file, but the file-level `#![allow(unsafe_code)]`
19//! is retained so that future SIMD specialisation (Phase 22) can add
20//! direct intrinsic paths here without re-plumbing the lint surface.
21#![allow(unsafe_code)]
22
23use alloc::boxed::Box;
24use alloc::sync::Arc;
25use alloc::vec;
26use alloc::vec::Vec;
27use core::cmp::Ordering;
28use core::fmt;
29
30use crate::codec::codec_config::CodecConfig;
31#[cfg(not(feature = "simd"))]
32use crate::codec::kernels::scalar as scalar_kernel;
33use crate::errors::CodecError;
34
/// Immutable lookup table mapping quantized `u8` indices to `f32` values.
///
/// Construction always validates three invariants:
///
/// 1. `entries.len() == 1 << bit_width` (matches `CodecConfig::num_codebook_entries`).
/// 2. Entries are strictly ascending under `f32::total_cmp`.
/// 3. All entries are distinct (no adjacent equals).
///
/// The inner buffer is an `Arc<[f32]>` so `Clone` is `O(1)`. Equality
/// compares the entries bit-for-bit (via `f32::to_bits`), not the
/// allocation identity — so `NaN`s with identical payloads compare
/// equal and `0.0` differs from `-0.0`.
#[derive(Clone)]
pub struct Codebook {
    // Sorted, strictly-ascending representatives, one per quantized
    // index. Shared so `Clone` is a refcount bump, not a buffer copy.
    entries: Arc<[f32]>,
    // Index width in bits; constructor enforces
    // `entries.len() == 1 << bit_width`.
    bit_width: u8,
}
50
51impl Codebook {
52 /// Build a codebook from a caller-owned `Box<[f32]>`.
53 ///
54 /// # Errors
55 ///
56 /// * [`CodecError::CodebookEntryCount`] — `entries.len()` does not
57 /// equal `2^bit_width`.
58 /// * [`CodecError::CodebookNotSorted`] — entries are not in
59 /// strictly ascending order (excluding duplicates).
60 /// * [`CodecError::CodebookDuplicate`] — two adjacent entries
61 /// compare equal.
62 pub fn new(entries: Box<[f32]>, bit_width: u8) -> Result<Self, CodecError> {
63 let expected = 1u32
64 .checked_shl(u32::from(bit_width))
65 .ok_or(CodecError::UnsupportedBitWidth { got: bit_width })?;
66 let got = u32::try_from(entries.len()).map_err(|_| CodecError::CodebookEntryCount {
67 expected,
68 got: u32::MAX,
69 bit_width,
70 })?;
71 if got != expected {
72 return Err(CodecError::CodebookEntryCount {
73 expected,
74 got,
75 bit_width,
76 });
77 }
78
79 // Validate monotonicity and distinctness in a single pass without
80 // indexing, to satisfy `clippy::indexing_slicing`.
81 let mut distinct: u32 = u32::from(!entries.is_empty());
82 let mut prev: Option<f32> = None;
83 for &value in &*entries {
84 if let Some(p) = prev {
85 match f32::total_cmp(&p, &value) {
86 Ordering::Less => distinct += 1,
87 Ordering::Equal => {
88 // Count distinct entries; `expected - distinct`
89 // is the minimum number of duplicates we must
90 // still resolve to make the codebook well-formed.
91 return Err(CodecError::CodebookDuplicate {
92 expected,
93 got: distinct,
94 });
95 }
96 Ordering::Greater => return Err(CodecError::CodebookNotSorted),
97 }
98 }
99 prev = Some(value);
100 }
101
102 Ok(Self {
103 entries: Arc::from(entries),
104 bit_width,
105 })
106 }
107
108 /// Train a codebook by uniform-quantile estimation over a flattened
109 /// f32 sample buffer.
110 ///
111 /// Mirrors Python's
112 /// `np.quantile(flat.astype(np.float64),
113 /// np.linspace(0, 1, num_entries)).astype(np.float32)` exactly:
114 ///
115 /// 1. Promote every sample to `f64`.
116 /// 2. Sort with `f64::total_cmp`.
117 /// 3. For each `k` in `0..num_entries`, compute the linearly-
118 /// interpolated quantile value in `f64`.
119 /// 4. Cast to `f32` (round-to-nearest-even) and enforce distinctness.
120 ///
121 /// `config.bit_width` determines the number of entries; `config.seed`
122 /// and `config.dimension` are not consulted by this function.
123 ///
124 /// # Errors
125 ///
126 /// * [`CodecError::InsufficientTrainingData`] — `vectors` is empty
127 /// or produces fewer than `num_entries` distinct quantile
128 /// representatives.
129 /// * Any error from [`Codebook::new`] on the freshly-built entries.
130 #[allow(
131 clippy::cast_precision_loss,
132 clippy::cast_possible_truncation,
133 clippy::cast_sign_loss
134 )]
135 pub fn train(vectors: &[f32], config: &CodecConfig) -> Result<Self, CodecError> {
136 let num_entries_u32 = config.num_codebook_entries();
137 let num_entries = num_entries_u32 as usize;
138 if vectors.is_empty() {
139 return Err(CodecError::InsufficientTrainingData {
140 expected: num_entries_u32,
141 });
142 }
143
144 // Step 1: promote to f64.
145 let mut flat: Vec<f64> = vectors.iter().copied().map(f64::from).collect();
146
147 // Step 2: sort with total ordering (matches NumPy's sort on non-
148 // NaN data; also makes NaNs deterministic if they ever sneak in).
149 flat.sort_by(f64::total_cmp);
150
151 // Step 3: quantile interpolation.
152 let len = flat.len();
153 let last_idx = len.saturating_sub(1);
154 let num_entries_minus_one =
155 num_entries
156 .checked_sub(1)
157 .ok_or(CodecError::InsufficientTrainingData {
158 expected: num_entries_u32,
159 })?;
160 let divisor = num_entries_minus_one as f64;
161 let span = last_idx as f64;
162
163 let mut entries_f32: Vec<f32> = Vec::with_capacity(num_entries);
164 for k in 0..num_entries {
165 let q = (k as f64) / divisor;
166 let h = q * span;
167 let floor_h = libm::floor(h);
168 let frac = h - floor_h;
169 let i = floor_h as usize;
170 let i_plus_one = i.saturating_add(1).min(last_idx);
171 let lo = *flat.get(i).ok_or(CodecError::InsufficientTrainingData {
172 expected: num_entries_u32,
173 })?;
174 let hi = *flat
175 .get(i_plus_one)
176 .ok_or(CodecError::InsufficientTrainingData {
177 expected: num_entries_u32,
178 })?;
179 let value_f64 = lo + frac * (hi - lo);
180 entries_f32.push(value_f64 as f32);
181 }
182
183 // Step 4: defensive sort (NumPy's linspace-driven quantiles are
184 // already monotone, but explicit sorting guards against any
185 // future change to the interpolation recipe).
186 entries_f32.sort_by(f32::total_cmp);
187
188 // Distinctness check — surface a dedicated error before handing
189 // off to `Codebook::new`, because the insufficient-training-data
190 // story is more informative than a raw duplicate-entries error
191 // in this context.
192 let mut distinct: u32 = 1;
193 let mut iter = entries_f32.iter();
194 if let Some(first) = iter.next() {
195 let mut prev = *first;
196 for &value in iter {
197 if f32::total_cmp(&prev, &value) == Ordering::Less {
198 distinct += 1;
199 }
200 prev = value;
201 }
202 }
203 if distinct < num_entries_u32 {
204 return Err(CodecError::InsufficientTrainingData {
205 expected: num_entries_u32,
206 });
207 }
208
209 // Delegating to `Codebook::new` re-runs the invariant checks so
210 // the constructor stays the single source of truth.
211 Self::new(entries_f32.into_boxed_slice(), config.bit_width())
212 }
213
214 /// Number of entries (`2^bit_width`).
215 #[inline]
216 pub fn num_entries(&self) -> u32 {
217 // Constructor guarantees the length fits in `u32`.
218 u32::try_from(self.entries.len()).unwrap_or(u32::MAX)
219 }
220
221 /// The bit width this codebook was built for.
222 #[inline]
223 pub const fn bit_width(&self) -> u8 {
224 self.bit_width
225 }
226
227 /// Borrow the underlying sorted entries.
228 #[inline]
229 pub fn entries(&self) -> &[f32] {
230 &self.entries
231 }
232
233 /// Quantize `values` into `indices` by finding the nearest entry for
234 /// each value. Ties favor the right (higher-valued) neighbor, matching
235 /// Python's strict `<` tie-break.
236 ///
237 /// Under `feature = "simd"` this delegates to
238 /// [`crate::codec::simd_api::quantize_into`], which is the single
239 /// source of truth for dispatch selection. Without the feature,
240 /// it calls the scalar reference kernel directly.
241 ///
242 /// # Errors
243 ///
244 /// * [`CodecError::LengthMismatch`] — `values.len() != indices.len()`.
245 pub fn quantize_into(&self, values: &[f32], indices: &mut [u8]) -> Result<(), CodecError> {
246 let entries = &self.entries;
247 #[cfg(feature = "simd")]
248 {
249 crate::codec::simd_api::quantize_into(entries, values, indices)
250 }
251 #[cfg(not(feature = "simd"))]
252 {
253 scalar_kernel::quantize_into(entries, values, indices)
254 }
255 }
256
257 /// Dequantize `indices` into `values` by gathering the corresponding
258 /// codebook entries.
259 ///
260 /// Under `feature = "simd"` this delegates to
261 /// [`crate::codec::simd_api::dequantize_into`].
262 ///
263 /// # Errors
264 ///
265 /// * [`CodecError::LengthMismatch`] — `indices.len() != values.len()`.
266 /// * [`CodecError::IndexOutOfRange`] — any index is
267 /// `>= num_entries()`.
268 pub fn dequantize_into(&self, indices: &[u8], values: &mut [f32]) -> Result<(), CodecError> {
269 let entries = &self.entries;
270 #[cfg(feature = "simd")]
271 {
272 crate::codec::simd_api::dequantize_into(entries, indices, values)
273 }
274 #[cfg(not(feature = "simd"))]
275 {
276 scalar_kernel::dequantize_into(entries, indices, values)
277 }
278 }
279
280 /// Convenience: allocate and return the quantized indices.
281 ///
282 /// # Errors
283 ///
284 /// See [`Codebook::quantize_into`].
285 pub fn quantize(&self, values: &[f32]) -> Result<Vec<u8>, CodecError> {
286 let mut out = vec![0u8; values.len()];
287 self.quantize_into(values, &mut out)?;
288 Ok(out)
289 }
290
291 /// Convenience: allocate and return the dequantized values.
292 ///
293 /// # Errors
294 ///
295 /// See [`Codebook::dequantize_into`].
296 pub fn dequantize(&self, indices: &[u8]) -> Result<Vec<f32>, CodecError> {
297 let mut out = vec![0.0f32; indices.len()];
298 self.dequantize_into(indices, &mut out)?;
299 Ok(out)
300 }
301}
302
303impl fmt::Debug for Codebook {
304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305 f.debug_struct("Codebook")
306 .field("bit_width", &self.bit_width)
307 .field("num_entries", &self.num_entries())
308 .field("entries", &self.entries)
309 .finish()
310 }
311}
312
impl PartialEq for Codebook {
    /// Bitwise equality: same bit width, same length, and every entry
    /// pair identical under `f32::to_bits`. Consequently `NaN`s with
    /// equal payloads compare equal and `0.0 != -0.0`, unlike `f32`'s
    /// own `==`.
    fn eq(&self, other: &Self) -> bool {
        if self.bit_width != other.bit_width {
            return false;
        }
        // Length guard: `zip` below would silently truncate to the
        // shorter slice, so unequal lengths must fail first.
        if self.entries.len() != other.entries.len() {
            return false;
        }
        self.entries
            .iter()
            .zip(other.entries.iter())
            .all(|(a, b)| a.to_bits() == b.to_bits())
    }
326}