Skip to main content

tinyquant_core/codec/
service.rs

//! Stateless `Codec` service (Phase 15).
//!
//! The pipeline mirrors `tinyquant_cpu.codec.Codec` exactly:
//!
//! - **compress**: rotate → quantize → (optional) residual of rotated vs. reconstructed values
//! - **decompress**: dequantize → (optional) add residual → inverse rotate
7
8use crate::codec::{
9    codebook::Codebook,
10    codec_config::CodecConfig,
11    compressed_vector::CompressedVector,
12    parallelism::Parallelism,
13    residual::{apply_residual_into, compute_residual},
14    rotation_matrix::RotationMatrix,
15};
16use crate::errors::CodecError;
17use alloc::{vec, vec::Vec};
18
/// Stateless codec service: a unit struct that carries no data.
///
/// Counterpart of the Python `tinyquant_cpu.codec.Codec`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Codec;
22
23impl Codec {
24    /// Create a new `Codec` instance (zero allocation).
25    #[must_use]
26    pub const fn new() -> Self {
27        Self
28    }
29
30    /// Compress a single vector. `vector.len()` must equal `config.dimension()`.
31    ///
32    /// Pipeline: rotate → quantize → (optional) residual.
33    ///
34    /// # Errors
35    ///
36    /// - [`CodecError::DimensionMismatch`] if `vector.len() != config.dimension()`
37    /// - [`CodecError::CodebookIncompatible`] if `codebook.bit_width() != config.bit_width()`
38    pub fn compress(
39        &self,
40        vector: &[f32],
41        config: &CodecConfig,
42        codebook: &Codebook,
43    ) -> Result<CompressedVector, CodecError> {
44        let dim = config.dimension() as usize;
45        if vector.len() != dim {
46            // Dimension is bounded by u32 (from CodecConfig); cast is safe in practice.
47            #[allow(clippy::cast_possible_truncation)]
48            let got = vector.len() as u32;
49            return Err(CodecError::DimensionMismatch {
50                expected: config.dimension(),
51                got,
52            });
53        }
54        if codebook.bit_width() != config.bit_width() {
55            return Err(CodecError::CodebookIncompatible {
56                expected: config.bit_width(),
57                got: codebook.bit_width(),
58            });
59        }
60
61        let rotation = RotationMatrix::from_config(config);
62        let mut rotated = vec![0.0_f32; dim];
63        rotation.apply_into(vector, &mut rotated)?;
64
65        let mut indices = vec![0_u8; dim];
66        codebook.quantize_into(&rotated, &mut indices)?;
67
68        let residual = if config.residual_enabled() {
69            let mut reconstructed = vec![0.0_f32; dim];
70            codebook.dequantize_into(&indices, &mut reconstructed)?;
71            Some(compute_residual(&rotated, &reconstructed).into_boxed_slice())
72        } else {
73            None
74        };
75
76        CompressedVector::new(
77            indices.into_boxed_slice(),
78            residual,
79            config.config_hash().clone(),
80            config.dimension(),
81            config.bit_width(),
82        )
83    }
84
85    /// Allocating decompress — returns a new `Vec<f32>`.
86    ///
87    /// # Errors
88    ///
89    /// Propagates errors from [`Self::decompress_into`].
90    pub fn decompress(
91        &self,
92        compressed: &CompressedVector,
93        config: &CodecConfig,
94        codebook: &Codebook,
95    ) -> Result<Vec<f32>, CodecError> {
96        let mut out = vec![0.0_f32; config.dimension() as usize];
97        self.decompress_into(compressed, config, codebook, &mut out)?;
98        Ok(out)
99    }
100
101    /// In-place decompress into caller-supplied buffer.
102    ///
103    /// Pipeline: dequantize → (optional) apply residual → inverse rotate.
104    ///
105    /// # Errors
106    ///
107    /// - [`CodecError::ConfigMismatch`] if `compressed.config_hash() != config.config_hash()`
108    /// - [`CodecError::CodebookIncompatible`] on bit-width mismatch
109    /// - [`CodecError::DimensionMismatch`] if `output.len() != config.dimension()`
110    pub fn decompress_into(
111        &self,
112        compressed: &CompressedVector,
113        config: &CodecConfig,
114        codebook: &Codebook,
115        output: &mut [f32],
116    ) -> Result<(), CodecError> {
117        if compressed.config_hash() != config.config_hash() {
118            return Err(CodecError::ConfigMismatch {
119                expected: config.config_hash().clone(),
120                got: compressed.config_hash().clone(),
121            });
122        }
123        if compressed.bit_width() != config.bit_width() {
124            return Err(CodecError::CodebookIncompatible {
125                expected: config.bit_width(),
126                got: compressed.bit_width(),
127            });
128        }
129        if codebook.bit_width() != config.bit_width() {
130            return Err(CodecError::CodebookIncompatible {
131                expected: config.bit_width(),
132                got: codebook.bit_width(),
133            });
134        }
135        if output.len() != config.dimension() as usize {
136            #[allow(clippy::cast_possible_truncation)]
137            let got = output.len() as u32;
138            return Err(CodecError::DimensionMismatch {
139                expected: config.dimension(),
140                got,
141            });
142        }
143
144        let mut rotated = vec![0.0_f32; output.len()];
145        codebook.dequantize_into(compressed.indices(), &mut rotated)?;
146
147        if let Some(residual) = compressed.residual() {
148            apply_residual_into(&mut rotated, residual)?;
149        }
150
151        let rotation = RotationMatrix::from_config(config);
152        rotation.apply_inverse_into(&rotated, output)
153    }
154
155    /// Row-major batch compress using the serial strategy.
156    ///
157    /// # Errors
158    ///
159    /// - [`CodecError::DimensionMismatch`] if `cols != config.dimension()`
160    /// - [`CodecError::LengthMismatch`] if `vectors.len() != rows * cols`
161    pub fn compress_batch(
162        &self,
163        vectors: &[f32],
164        rows: usize,
165        cols: usize,
166        config: &CodecConfig,
167        codebook: &Codebook,
168    ) -> Result<Vec<CompressedVector>, CodecError> {
169        self.compress_batch_with(vectors, rows, cols, config, codebook, Parallelism::Serial)
170    }
171
172    /// Row-major batch compress with explicit parallelism strategy.
173    ///
174    /// Phase 21: honours `parallelism`. `Serial` runs the existing single-threaded
175    /// loop; `Custom(driver)` uses the `MaybeUninit + AtomicPtr<CompressedVector>` parallel path in
176    /// `batch.rs` (requires the `std` feature).
177    ///
178    /// The determinism contract guarantees byte-identical output regardless of
179    /// the driver or thread count (see `batch.rs` module doc).
180    ///
181    /// # Errors
182    ///
183    /// Same as [`Self::compress_batch`].
184    pub fn compress_batch_with(
185        &self,
186        vectors: &[f32],
187        rows: usize,
188        cols: usize,
189        config: &CodecConfig,
190        codebook: &Codebook,
191        parallelism: Parallelism,
192    ) -> Result<Vec<CompressedVector>, CodecError> {
193        if cols != config.dimension() as usize {
194            #[allow(clippy::cast_possible_truncation)]
195            let got = cols as u32;
196            return Err(CodecError::DimensionMismatch {
197                expected: config.dimension(),
198                got,
199            });
200        }
201        let expected_len = rows.checked_mul(cols).ok_or(CodecError::LengthMismatch {
202            left: vectors.len(),
203            right: usize::MAX,
204        })?;
205        if vectors.len() != expected_len {
206            return Err(CodecError::LengthMismatch {
207                left: vectors.len(),
208                right: expected_len,
209            });
210        }
211        // Delegate to the parallel batch module when `std` is available.
212        #[cfg(feature = "std")]
213        {
214            crate::codec::batch::compress_batch_parallel(
215                vectors,
216                rows,
217                cols,
218                config,
219                codebook,
220                parallelism,
221            )
222        }
223        // no_std fallback: always serial regardless of `parallelism` argument.
224        #[cfg(not(feature = "std"))]
225        {
226            let _ = parallelism;
227            let mut out = Vec::with_capacity(rows);
228            // Safety: vectors.len() == rows * cols (checked above); slices are in-bounds.
229            #[allow(clippy::indexing_slicing)]
230            for row in 0..rows {
231                let start = row * cols;
232                out.push(self.compress(&vectors[start..start + cols], config, codebook)?);
233            }
234            Ok(out)
235        }
236    }
237
238    /// Batch decompress into a contiguous row-major `output` buffer.
239    ///
240    /// # Errors
241    ///
242    /// - [`CodecError::LengthMismatch`] if `output.len() != compressed.len() * config.dimension()`
243    /// - Propagates per-vector decompress errors.
244    pub fn decompress_batch_into(
245        &self,
246        compressed: &[CompressedVector],
247        config: &CodecConfig,
248        codebook: &Codebook,
249        output: &mut [f32],
250    ) -> Result<(), CodecError> {
251        let cols = config.dimension() as usize;
252        let needed = compressed.len() * cols;
253        if output.len() != needed {
254            return Err(CodecError::LengthMismatch {
255                left: output.len(),
256                right: needed,
257            });
258        }
259        // Safety: output.len() == compressed.len() * cols (checked above); slices are in-bounds.
260        #[allow(clippy::indexing_slicing)]
261        for (row, cv) in compressed.iter().enumerate() {
262            let start = row * cols;
263            self.decompress_into(cv, config, codebook, &mut output[start..start + cols])?;
264        }
265        Ok(())
266    }
267}
268
269/// Module-level `compress` free function — mirrors `tinyquant_cpu.codec.compress`.
270///
271/// # Errors
272///
273/// Propagates errors from [`Codec::compress`].
274pub fn compress(
275    vector: &[f32],
276    config: &CodecConfig,
277    codebook: &Codebook,
278) -> Result<CompressedVector, CodecError> {
279    Codec::new().compress(vector, config, codebook)
280}
281
282/// Module-level `decompress` free function — mirrors `tinyquant_cpu.codec.decompress`.
283///
284/// # Errors
285///
286/// Propagates errors from [`Codec::decompress`].
287pub fn decompress(
288    compressed: &CompressedVector,
289    config: &CodecConfig,
290    codebook: &Codebook,
291) -> Result<Vec<f32>, CodecError> {
292    Codec::new().decompress(compressed, config, codebook)
293}