// tinyquant_core/codec/service.rs

1//! Stateless `Codec` service (Phase 15).
2//!
3//! Pipeline mirrors `tinyquant_cpu.codec.Codec` exactly:
4//!
5//! - **compress**: rotate → quantize → (optional) residual on rotated vs reconstructed
6//! - **decompress**: dequantize → (optional) add residual → inverse rotate
7//!
8//! Phase 26 adds `compress_prepared` / `decompress_prepared_into` which accept a
9//! pre-built [`PreparedCodec`] so the `O(dim²)` rotation factorization is paid
10//! only once per session rather than on every call.
11
12use crate::codec::{
13    codebook::Codebook,
14    codec_config::CodecConfig,
15    compressed_vector::CompressedVector,
16    parallelism::Parallelism,
17    prepared::PreparedCodec,
18    residual::{apply_residual_into, compute_residual},
19    rotation_matrix::RotationMatrix,
20};
21use crate::errors::CodecError;
22use alloc::{vec, vec::Vec};
23
/// Zero-sized stateless codec service. Mirrors Python `tinyquant_cpu.codec.Codec`.
///
/// Carries no state of its own: the configuration, codebook, and rotation are
/// supplied on every call (or pre-bundled in a `PreparedCodec` for the hot
/// paths), which is why the type can derive `Copy` and costs nothing to pass
/// around.
#[derive(Default, Debug, Clone, Copy)]
pub struct Codec;
27
28impl Codec {
29    /// Create a new `Codec` instance (zero allocation).
30    #[must_use]
31    pub const fn new() -> Self {
32        Self
33    }
34
35    /// Compress a single vector. `vector.len()` must equal `config.dimension()`.
36    ///
37    /// Pipeline: rotate → quantize → (optional) residual.
38    ///
39    /// # Errors
40    ///
41    /// - [`CodecError::DimensionMismatch`] if `vector.len() != config.dimension()`
42    /// - [`CodecError::CodebookIncompatible`] if `codebook.bit_width() != config.bit_width()`
43    pub fn compress(
44        &self,
45        vector: &[f32],
46        config: &CodecConfig,
47        codebook: &Codebook,
48    ) -> Result<CompressedVector, CodecError> {
49        let dim = config.dimension() as usize;
50        if vector.len() != dim {
51            // Dimension is bounded by u32 (from CodecConfig); cast is safe in practice.
52            #[allow(clippy::cast_possible_truncation)]
53            let got = vector.len() as u32;
54            return Err(CodecError::DimensionMismatch {
55                expected: config.dimension(),
56                got,
57            });
58        }
59        if codebook.bit_width() != config.bit_width() {
60            return Err(CodecError::CodebookIncompatible {
61                expected: config.bit_width(),
62                got: codebook.bit_width(),
63            });
64        }
65        let rotation = RotationMatrix::from_config(config);
66        Self::compress_with_rotation(vector, config, codebook, &rotation)
67    }
68
69    /// Compress using a pre-built [`PreparedCodec`] (hot path — rotation not rebuilt).
70    ///
71    /// Identical output to [`Self::compress`] for the same inputs.
72    ///
73    /// # Errors
74    ///
75    /// [`CodecError::DimensionMismatch`] if `vector.len() != prepared.config().dimension()`.
76    pub fn compress_prepared(
77        &self,
78        vector: &[f32],
79        prepared: &PreparedCodec,
80    ) -> Result<CompressedVector, CodecError> {
81        let dim = prepared.config().dimension() as usize;
82        if vector.len() != dim {
83            #[allow(clippy::cast_possible_truncation)]
84            let got = vector.len() as u32;
85            return Err(CodecError::DimensionMismatch {
86                expected: prepared.config().dimension(),
87                got,
88            });
89        }
90        Self::compress_with_rotation(
91            vector,
92            prepared.config(),
93            prepared.codebook(),
94            prepared.rotation(),
95        )
96    }
97
98    /// Inner compute path — caller supplies an already-built rotation.
99    ///
100    /// Preconditions (asserted by callers):
101    /// - `vector.len() == config.dimension()`
102    /// - `codebook.bit_width() == config.bit_width()`
103    fn compress_with_rotation(
104        vector: &[f32],
105        config: &CodecConfig,
106        codebook: &Codebook,
107        rotation: &RotationMatrix,
108    ) -> Result<CompressedVector, CodecError> {
109        let dim = config.dimension() as usize;
110        let mut rotated = vec![0.0_f32; dim];
111        rotation.apply_into(vector, &mut rotated)?;
112
113        let mut indices = vec![0_u8; dim];
114        codebook.quantize_into(&rotated, &mut indices)?;
115
116        let residual = if config.residual_enabled() {
117            let mut reconstructed = vec![0.0_f32; dim];
118            codebook.dequantize_into(&indices, &mut reconstructed)?;
119            Some(compute_residual(&rotated, &reconstructed).into_boxed_slice())
120        } else {
121            None
122        };
123
124        CompressedVector::new(
125            indices.into_boxed_slice(),
126            residual,
127            config.config_hash().clone(),
128            config.dimension(),
129            config.bit_width(),
130        )
131    }
132
133    /// Allocating decompress — returns a new `Vec<f32>`.
134    ///
135    /// # Errors
136    ///
137    /// Propagates errors from [`Self::decompress_into`].
138    pub fn decompress(
139        &self,
140        compressed: &CompressedVector,
141        config: &CodecConfig,
142        codebook: &Codebook,
143    ) -> Result<Vec<f32>, CodecError> {
144        let mut out = vec![0.0_f32; config.dimension() as usize];
145        self.decompress_into(compressed, config, codebook, &mut out)?;
146        Ok(out)
147    }
148
149    /// In-place decompress into caller-supplied buffer.
150    ///
151    /// Pipeline: dequantize → (optional) apply residual → inverse rotate.
152    ///
153    /// # Errors
154    ///
155    /// - [`CodecError::ConfigMismatch`] if `compressed.config_hash() != config.config_hash()`
156    /// - [`CodecError::CodebookIncompatible`] on bit-width mismatch
157    /// - [`CodecError::DimensionMismatch`] if `output.len() != config.dimension()`
158    pub fn decompress_into(
159        &self,
160        compressed: &CompressedVector,
161        config: &CodecConfig,
162        codebook: &Codebook,
163        output: &mut [f32],
164    ) -> Result<(), CodecError> {
165        if compressed.config_hash() != config.config_hash() {
166            return Err(CodecError::ConfigMismatch {
167                expected: config.config_hash().clone(),
168                got: compressed.config_hash().clone(),
169            });
170        }
171        if compressed.bit_width() != config.bit_width() {
172            return Err(CodecError::CodebookIncompatible {
173                expected: config.bit_width(),
174                got: compressed.bit_width(),
175            });
176        }
177        if codebook.bit_width() != config.bit_width() {
178            return Err(CodecError::CodebookIncompatible {
179                expected: config.bit_width(),
180                got: codebook.bit_width(),
181            });
182        }
183        if output.len() != config.dimension() as usize {
184            #[allow(clippy::cast_possible_truncation)]
185            let got = output.len() as u32;
186            return Err(CodecError::DimensionMismatch {
187                expected: config.dimension(),
188                got,
189            });
190        }
191        let rotation = RotationMatrix::from_config(config);
192        Self::decompress_into_with_rotation(compressed, codebook, &rotation, output)
193    }
194
195    /// Decompress into a caller-allocated buffer using a pre-built [`PreparedCodec`]
196    /// (hot path — rotation not rebuilt).
197    ///
198    /// Identical output to [`Self::decompress_into`] for the same inputs.
199    ///
200    /// # Errors
201    ///
202    /// - [`CodecError::ConfigMismatch`] if `cv.config_hash() != prepared.config().config_hash()`
203    /// - [`CodecError::DimensionMismatch`] if `out.len() != prepared.config().dimension()`
204    pub fn decompress_prepared_into(
205        &self,
206        cv: &CompressedVector,
207        prepared: &PreparedCodec,
208        out: &mut [f32],
209    ) -> Result<(), CodecError> {
210        if cv.config_hash() != prepared.config().config_hash() {
211            return Err(CodecError::ConfigMismatch {
212                expected: prepared.config().config_hash().clone(),
213                got: cv.config_hash().clone(),
214            });
215        }
216        // cv.bit_width() is not checked separately — the config_hash equality
217        // above already covers all config fields including bit_width.  The
218        // codebook bit_width is validated once at PreparedCodec::new time.
219        if out.len() != prepared.config().dimension() as usize {
220            #[allow(clippy::cast_possible_truncation)]
221            let got = out.len() as u32;
222            return Err(CodecError::DimensionMismatch {
223                expected: prepared.config().dimension(),
224                got,
225            });
226        }
227        Self::decompress_into_with_rotation(cv, prepared.codebook(), prepared.rotation(), out)
228    }
229
230    /// Inner compute path — caller supplies an already-built rotation.
231    ///
232    /// Preconditions (asserted by callers):
233    /// - `output.len() == config dimension`
234    /// - `codebook` is compatible with the compressed vector
235    fn decompress_into_with_rotation(
236        compressed: &CompressedVector,
237        codebook: &Codebook,
238        rotation: &RotationMatrix,
239        output: &mut [f32],
240    ) -> Result<(), CodecError> {
241        let mut rotated = vec![0.0_f32; output.len()];
242        codebook.dequantize_into(compressed.indices(), &mut rotated)?;
243
244        if let Some(residual) = compressed.residual() {
245            apply_residual_into(&mut rotated, residual)?;
246        }
247
248        rotation.apply_inverse_into(&rotated, output)
249    }
250
251    /// Row-major batch compress using the serial strategy.
252    ///
253    /// # Errors
254    ///
255    /// - [`CodecError::DimensionMismatch`] if `cols != config.dimension()`
256    /// - [`CodecError::LengthMismatch`] if `vectors.len() != rows * cols`
257    pub fn compress_batch(
258        &self,
259        vectors: &[f32],
260        rows: usize,
261        cols: usize,
262        config: &CodecConfig,
263        codebook: &Codebook,
264    ) -> Result<Vec<CompressedVector>, CodecError> {
265        self.compress_batch_with(vectors, rows, cols, config, codebook, Parallelism::Serial)
266    }
267
268    /// Row-major batch compress with explicit parallelism strategy.
269    ///
270    /// Phase 21: honours `parallelism`. `Serial` runs the existing single-threaded
271    /// loop; `Custom(driver)` uses the `MaybeUninit + AtomicPtr<CompressedVector>` parallel path in
272    /// `batch.rs` (requires the `std` feature).
273    ///
274    /// The determinism contract guarantees byte-identical output regardless of
275    /// the driver or thread count (see `batch.rs` module doc).
276    ///
277    /// # Errors
278    ///
279    /// Same as [`Self::compress_batch`].
280    pub fn compress_batch_with(
281        &self,
282        vectors: &[f32],
283        rows: usize,
284        cols: usize,
285        config: &CodecConfig,
286        codebook: &Codebook,
287        parallelism: Parallelism,
288    ) -> Result<Vec<CompressedVector>, CodecError> {
289        if cols != config.dimension() as usize {
290            #[allow(clippy::cast_possible_truncation)]
291            let got = cols as u32;
292            return Err(CodecError::DimensionMismatch {
293                expected: config.dimension(),
294                got,
295            });
296        }
297        let expected_len = rows.checked_mul(cols).ok_or(CodecError::LengthMismatch {
298            left: vectors.len(),
299            right: usize::MAX,
300        })?;
301        if vectors.len() != expected_len {
302            return Err(CodecError::LengthMismatch {
303                left: vectors.len(),
304                right: expected_len,
305            });
306        }
307        // Delegate to the parallel batch module when `std` is available.
308        #[cfg(feature = "std")]
309        {
310            crate::codec::batch::compress_batch_parallel(
311                vectors,
312                rows,
313                cols,
314                config,
315                codebook,
316                parallelism,
317            )
318        }
319        // no_std fallback: always serial regardless of `parallelism` argument.
320        #[cfg(not(feature = "std"))]
321        {
322            let _ = parallelism;
323            let mut out = Vec::with_capacity(rows);
324            // Safety: vectors.len() == rows * cols (checked above); slices are in-bounds.
325            #[allow(clippy::indexing_slicing)]
326            for row in 0..rows {
327                let start = row * cols;
328                out.push(self.compress(&vectors[start..start + cols], config, codebook)?);
329            }
330            Ok(out)
331        }
332    }
333
334    /// Batch decompress into a contiguous row-major `output` buffer.
335    ///
336    /// # Errors
337    ///
338    /// - [`CodecError::LengthMismatch`] if `output.len() != compressed.len() * config.dimension()`
339    /// - Propagates per-vector decompress errors.
340    pub fn decompress_batch_into(
341        &self,
342        compressed: &[CompressedVector],
343        config: &CodecConfig,
344        codebook: &Codebook,
345        output: &mut [f32],
346    ) -> Result<(), CodecError> {
347        let cols = config.dimension() as usize;
348        let needed = compressed.len() * cols;
349        if output.len() != needed {
350            return Err(CodecError::LengthMismatch {
351                left: output.len(),
352                right: needed,
353            });
354        }
355        // Safety: output.len() == compressed.len() * cols (checked above); slices are in-bounds.
356        #[allow(clippy::indexing_slicing)]
357        for (row, cv) in compressed.iter().enumerate() {
358            let start = row * cols;
359            self.decompress_into(cv, config, codebook, &mut output[start..start + cols])?;
360        }
361        Ok(())
362    }
363}
364
/// Minimum batch size below which GPU offload is not attempted.
///
/// Below this threshold the host↔device transfer overhead exceeds the
/// compute savings.  Mirrors `tinyquant_gpu_wgpu::GPU_BATCH_THRESHOLD`.
///
/// Consulted by `Codec::compress_batch_gpu_with` (behind the `gpu-wgpu`
/// feature): batches with `rows >= GPU_BATCH_THRESHOLD` are routed to the
/// GPU backend, smaller ones stay on the CPU path.
pub const GPU_BATCH_THRESHOLD: usize = 512;
370
/// Trait that every `TinyQuant` GPU compute backend must satisfy.
///
/// Implementations of this trait live in external crates (e.g.
/// `tinyquant-gpu-wgpu`). The trait is defined here so
/// `tinyquant-core` can express the `compress_batch_gpu_with` method
/// without a dependency on any concrete GPU crate (which would create
/// a cyclic crate dependency, since GPU crates already depend on
/// `tinyquant-core`).
///
/// Both methods take `&mut self`, allowing implementations to lazily
/// initialize and cache device-resident state between calls.
pub trait GpuComputeBackend {
    /// The error type returned by this backend's operations.
    ///
    /// Must be convertible to [`CodecError`] so
    /// [`Codec::compress_batch_gpu_with`] can map errors uniformly.
    type Error: core::fmt::Debug + Into<crate::errors::CodecError>;

    /// Upload `PreparedCodec` buffers to device memory. Idempotent.
    ///
    /// Takes `prepared` by mutable reference so the backend can attach
    /// device-side handles to it.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if device upload fails.
    fn prepare_for_device(&mut self, prepared: &mut PreparedCodec) -> Result<(), Self::Error>;

    /// Compress `rows` FP32 vectors of dimension `cols` on the GPU.
    ///
    /// `input` is the row-major concatenation of all vectors (`rows * cols`
    /// elements — NOTE(review): length contract presumed from the CPU batch
    /// API; confirm against concrete backends).
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the GPU kernel dispatch or readback fails.
    fn compress_batch(
        &mut self,
        input: &[f32],
        rows: usize,
        cols: usize,
        prepared: &PreparedCodec,
    ) -> Result<alloc::vec::Vec<CompressedVector>, Self::Error>;
}
406
#[cfg(feature = "gpu-wgpu")]
impl Codec {
    /// Batch compress with automatic GPU routing.
    ///
    /// Batches of at least [`GPU_BATCH_THRESHOLD`] rows (currently 512) are
    /// handed to `backend` through the [`GpuComputeBackend`] trait; anything
    /// smaller runs on the CPU via [`Codec::compress_batch_with`] using the
    /// supplied `parallelism` strategy.
    ///
    /// # Caller responsibility
    ///
    /// The caller must have created and initialized `backend` beforehand
    /// (which may require an async executor, depending on the backend); this
    /// method invokes only synchronous trait methods on it.
    ///
    /// `prepared` is `&mut` so `prepare_for_device` can attach GPU-resident
    /// state to it — the upload is idempotent, so calling it before every
    /// batch is safe.
    ///
    /// # Async-context safety
    ///
    /// Do **not** call this function from an async context that owns an active
    /// tokio or other executor on the same thread. `wgpu` may internally call
    /// `device.poll(wgpu::Maintain::Wait)`, which blocks the thread.
    ///
    /// # Errors
    ///
    /// - [`CodecError::GpuUnavailable`] if `prepare_for_device` fails.
    /// - [`CodecError::GpuError`] if the GPU kernel dispatch or readback fails.
    /// - All errors from [`Codec::compress_batch_with`] when the CPU fallback
    ///   is taken.
    pub fn compress_batch_gpu_with<B>(
        &self,
        vectors: &[f32],
        rows: usize,
        cols: usize,
        prepared: &mut PreparedCodec,
        backend: &mut B,
        parallelism: Parallelism,
    ) -> Result<Vec<CompressedVector>, CodecError>
    where
        B: GpuComputeBackend,
    {
        // Small batches: transfer overhead dominates, so stay on the CPU path.
        if rows < GPU_BATCH_THRESHOLD {
            return self.compress_batch_with(
                vectors,
                rows,
                cols,
                prepared.config(),
                prepared.codebook(),
                parallelism,
            );
        }
        backend.prepare_for_device(prepared).map_err(Into::into)?;
        let result = backend.compress_batch(vectors, rows, cols, prepared);
        result.map_err(Into::into)
    }
}
465
466/// Module-level `compress` free function — mirrors `tinyquant_cpu.codec.compress`.
467///
468/// # Errors
469///
470/// Propagates errors from [`Codec::compress`].
471pub fn compress(
472    vector: &[f32],
473    config: &CodecConfig,
474    codebook: &Codebook,
475) -> Result<CompressedVector, CodecError> {
476    Codec::new().compress(vector, config, codebook)
477}
478
479/// Module-level `decompress` free function — mirrors `tinyquant_cpu.codec.decompress`.
480///
481/// # Errors
482///
483/// Propagates errors from [`Codec::decompress`].
484pub fn decompress(
485    compressed: &CompressedVector,
486    config: &CodecConfig,
487    codebook: &Codebook,
488) -> Result<Vec<f32>, CodecError> {
489    Codec::new().decompress(compressed, config, codebook)
490}