Skip to main content

tinyquant_core/codec/
codec_config.rs

//! `CodecConfig`: immutable value object describing codec parameters.
//!
//! Mirrors `tinyquant_cpu.codec.codec_config.CodecConfig` field-for-field.
//! The canonical `config_hash` is the hex-encoded SHA-256 digest of a fixed
//! canonical string. It matches the Python reference byte-for-byte, so
//! artifacts produced by either implementation are interchangeable.
//!
//! The canonical string format is:
//!
//! ```text
//! CodecConfig(bit_width={b},seed={s},dimension={d},residual_enabled={r})
//! ```
//!
//! where `{r}` is the bool spelled exactly as Python's `str(bool)` spells it:
//! `"True"` or `"False"` (capitalized). Any deviation (extra spaces, different
//! case, reordered fields) breaks parity.

18use alloc::format;
19use alloc::sync::Arc;
20
21use sha2::{Digest, Sha256};
22
23use crate::errors::CodecError;
24use crate::types::ConfigHash;
25
/// Every quantization bit width `TinyQuant` accepts.
///
/// Mirrors `tinyquant_cpu.codec.codec_config.SUPPORTED_BIT_WIDTHS`; any other
/// width is rejected by `CodecConfig::new`.
pub const SUPPORTED_BIT_WIDTHS: &[u8] = &[2_u8, 4, 8];

/// Upper bound on `dimension` for any codec config or rotation matrix.
///
/// `RotationMatrix::build` keeps `dim²` `f64` values and runs an `O(dim³)`
/// QR decomposition. At this cap the rotation alone needs
/// `16384² × 8 B = 2 GiB`, and build time stays tractable on commodity
/// hardware. Real-world embedding sizes (e.g. 4096) are far below it.
/// Phase 28.7 introduced the cap as defence-in-depth against `DoS` through
/// the `RotationMatrix::from_seed_and_dim` `PyO3` entry point.
pub const MAX_DIMENSION: u32 = 1 << 14; // = 16_384
40
41/// Immutable configuration snapshot that fully determines codec behavior.
42///
43/// Two configs with identical primary fields are interchangeable. The
44/// cached `config_hash` is computed eagerly in [`CodecConfig::new`] and
45/// ignored by [`PartialEq`] / [`Hash`] so semantically equal configs
46/// compare equal regardless of which instance owns the `Arc<str>`.
47#[derive(Clone, Debug)]
48pub struct CodecConfig {
49    bit_width: u8,
50    seed: u64,
51    dimension: u32,
52    residual_enabled: bool,
53    config_hash: ConfigHash,
54}
55
56impl CodecConfig {
57    /// Validate the field invariants and return a new `CodecConfig`.
58    ///
59    /// # Errors
60    ///
61    /// * [`CodecError::UnsupportedBitWidth`] — `bit_width` is not in
62    ///   [`SUPPORTED_BIT_WIDTHS`].
63    /// * [`CodecError::InvalidDimension`] — `dimension == 0`.
64    /// * [`CodecError::DimensionTooLarge`] — `dimension > MAX_DIMENSION`.
65    pub fn new(
66        bit_width: u8,
67        seed: u64,
68        dimension: u32,
69        residual_enabled: bool,
70    ) -> Result<Self, CodecError> {
71        if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
72            return Err(CodecError::UnsupportedBitWidth { got: bit_width });
73        }
74        if dimension == 0 {
75            return Err(CodecError::InvalidDimension { got: 0 });
76        }
77        if dimension > MAX_DIMENSION {
78            return Err(CodecError::DimensionTooLarge {
79                got: dimension,
80                max: MAX_DIMENSION,
81            });
82        }
83        let config_hash = compute_config_hash(bit_width, seed, dimension, residual_enabled);
84        Ok(Self {
85            bit_width,
86            seed,
87            dimension,
88            residual_enabled,
89            config_hash,
90        })
91    }
92
93    /// The bit width of the quantized indices.
94    #[inline]
95    pub const fn bit_width(&self) -> u8 {
96        self.bit_width
97    }
98
99    /// The seed used for deterministic rotation and codebook generation.
100    #[inline]
101    pub const fn seed(&self) -> u64 {
102        self.seed
103    }
104
105    /// The expected input vector dimensionality.
106    #[inline]
107    pub const fn dimension(&self) -> u32 {
108        self.dimension
109    }
110
111    /// Whether stage-2 residual correction is enabled.
112    #[inline]
113    pub const fn residual_enabled(&self) -> bool {
114        self.residual_enabled
115    }
116
117    /// `2^bit_width` — the number of quantization levels in the codebook.
118    #[inline]
119    pub const fn num_codebook_entries(&self) -> u32 {
120        1u32 << self.bit_width
121    }
122
123    /// Cached SHA-256 hex digest of the canonical string representation.
124    ///
125    /// Returned as an `Arc<str>` borrow; clone with `.clone()` for owned use.
126    #[inline]
127    pub const fn config_hash(&self) -> &ConfigHash {
128        &self.config_hash
129    }
130}
131
132impl PartialEq for CodecConfig {
133    fn eq(&self, other: &Self) -> bool {
134        self.bit_width == other.bit_width
135            && self.seed == other.seed
136            && self.dimension == other.dimension
137            && self.residual_enabled == other.residual_enabled
138    }
139}
140
141impl Eq for CodecConfig {}
142
143impl core::hash::Hash for CodecConfig {
144    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
145        self.bit_width.hash(state);
146        self.seed.hash(state);
147        self.dimension.hash(state);
148        self.residual_enabled.hash(state);
149    }
150}
151
152/// Compute the canonical SHA-256 `config_hash` for the given field tuple.
153///
154/// Exposed at module scope (rather than as an associated function) so that
155/// tests can verify hash computation without constructing a full config.
156/// Kept `pub(crate)` because nothing outside the crate should need to
157/// compute a hash without going through [`CodecConfig::new`].
158pub(crate) fn compute_config_hash(
159    bit_width: u8,
160    seed: u64,
161    dimension: u32,
162    residual_enabled: bool,
163) -> ConfigHash {
164    // CRITICAL: Python bool stringifies as "True" / "False" (capitalized).
165    // See scripts/generate_rust_fixtures.py and the Python reference in
166    // src/tinyquant_cpu/codec/codec_config.py.
167    let canonical = format!(
168        "CodecConfig(bit_width={b},seed={s},dimension={d},residual_enabled={r})",
169        b = bit_width,
170        s = seed,
171        d = dimension,
172        r = if residual_enabled { "True" } else { "False" },
173    );
174    let digest = Sha256::digest(canonical.as_bytes());
175    Arc::from(hex::encode(digest).as_str())
176}