Skip to main content

tinyquant_core/codec/
residual.rs

1//! FP16 residual helpers (Phase 15).
2//!
3//! `compute_residual(original, reconstructed) = (original - reconstructed)` stored
4//! as little-endian IEEE 754 binary16, matching Python's
5//! `(original - reconstructed).astype(np.float16).tobytes()`.
6//!
7//! `apply_residual_into` adds the decoded residual back into a reconstruction buffer
8//! in place.
9
10use crate::errors::CodecError;
11use alloc::vec::Vec;
12use half::f16;
13
14/// Compute the per-element residual `original - reconstructed` as little-endian f16 bytes.
15///
16/// The output has length `original.len() * 2`. Caller guarantees `original.len() ==
17/// reconstructed.len()`.
18#[must_use]
19pub fn compute_residual(original: &[f32], reconstructed: &[f32]) -> Vec<u8> {
20    debug_assert_eq!(original.len(), reconstructed.len());
21    let mut out = Vec::with_capacity(original.len() * 2);
22    for (o, r) in original.iter().zip(reconstructed.iter()) {
23        let diff = f16::from_f32(*o - *r);
24        out.extend_from_slice(&diff.to_le_bytes());
25    }
26    out
27}
28
29/// Decode a residual byte-slice as f16 LE values and add them in place into `values`.
30///
31/// Returns [`CodecError::LengthMismatch`] if `residual.len() != values.len() * 2`.
32///
33/// # Errors
34///
35/// Returns [`CodecError::LengthMismatch`] when the residual buffer has the wrong length.
36pub fn apply_residual_into(values: &mut [f32], residual: &[u8]) -> Result<(), CodecError> {
37    let expected = values.len() * 2;
38    if residual.len() != expected {
39        return Err(CodecError::LengthMismatch {
40            left: residual.len(),
41            right: expected,
42        });
43    }
44    // chunks_exact(2) guarantees 2-byte chunks; zip with values guarantees
45    // in-bounds access since residual.len() == values.len() * 2 (checked above).
46    #[allow(clippy::indexing_slicing)]
47    for (v, chunk) in values.iter_mut().zip(residual.chunks_exact(2)) {
48        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
49        *v += f16::from_bits(bits).to_f32();
50    }
51    Ok(())
52}