Skip to main content

vortex_tensor/encodings/turboquant/scheme/
compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! TurboQuant encoding (quantization) logic.
5//!
6//! The input to [`turboquant_encode`] must be a non-nullable [`Vector`](crate::vector::Vector)
7//! extension array whose rows are already L2-normalized (unit norm). Normalization is handled
8//! externally by [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm),
9//! which the [`TurboQuantScheme`](super::TurboQuantScheme) calls before invoking this function.
10
11use vortex_array::ArrayRef;
12use vortex_array::ArrayView;
13use vortex_array::ExecutionCtx;
14use vortex_array::IntoArray;
15use vortex_array::arrays::Extension;
16use vortex_array::arrays::FixedSizeListArray;
17use vortex_array::arrays::PrimitiveArray;
18use vortex_array::arrays::extension::ExtensionArrayExt;
19use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
20use vortex_array::dtype::DType;
21use vortex_array::dtype::Nullability;
22use vortex_array::dtype::PType;
23use vortex_array::validity::Validity;
24use vortex_buffer::BufferMut;
25use vortex_error::VortexExpect;
26use vortex_error::VortexResult;
27use vortex_error::vortex_bail;
28use vortex_error::vortex_ensure;
29use vortex_fastlanes::bitpack_compress::bitpack_encode;
30
31use crate::encodings::turboquant::TurboQuant;
32use crate::encodings::turboquant::array::centroids::compute_centroid_boundaries;
33use crate::encodings::turboquant::array::centroids::find_nearest_centroid;
34use crate::encodings::turboquant::array::centroids::get_centroids;
35use crate::encodings::turboquant::array::rotation::RotationMatrix;
36use crate::encodings::turboquant::vtable::TurboQuantArray;
37use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows;
38
39/// Configuration for TurboQuant encoding.
40#[derive(Clone, Debug)]
41pub struct TurboQuantConfig {
42    /// Bits per coordinate (1-8).
43    pub bit_width: u8,
44    /// Optional seed for the rotation matrix. If None, the default seed is used.
45    pub seed: Option<u64>,
46    /// Number of sign-diagonal + WHT rounds in the structured rotation (default 3).
47    pub num_rounds: u8,
48}
49
50impl Default for TurboQuantConfig {
51    fn default() -> Self {
52        Self {
53            bit_width: TurboQuant::MAX_BIT_WIDTH,
54            seed: Some(42),
55            num_rounds: 3,
56        }
57    }
58}
59
60/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray for quantization.
61///
62/// All quantization (rotation, centroid lookup) happens in f32. f16 is upcast; f64 is truncated.
63fn extract_f32_elements(
64    fsl: &FixedSizeListArray,
65    ctx: &mut ExecutionCtx,
66) -> VortexResult<PrimitiveArray> {
67    let elements = fsl.elements();
68    let primitive = elements.clone().execute::<PrimitiveArray>(ctx)?;
69    let ptype = primitive.ptype();
70
71    match ptype {
72        PType::F16 => Ok(primitive
73            .as_slice::<half::f16>()
74            .iter()
75            .map(|&v| f32::from(v))
76            .collect()),
77        PType::F32 => Ok(primitive),
78        PType::F64 => Ok(primitive
79            .as_slice::<f64>()
80            .iter()
81            .map(|&v| {
82                #[expect(
83                    clippy::cast_possible_truncation,
84                    reason = "TurboQuant quantization operates in f32, so f64 inputs are intentionally downcast"
85                )]
86                let v = v as f32;
87                v
88            })
89            .collect()),
90        _ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"),
91    }
92}
93
94/// Shared intermediate results from the quantization loop.
95struct QuantizationResult {
96    rotation: RotationMatrix,
97    centroids: Vec<f32>,
98    all_indices: BufferMut<u8>,
99    padded_dim: usize,
100}
101
102/// Core quantization: rotate and quantize already-normalized rows.
103///
104/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null
105/// vectors are not supported and must be zeroed out before reaching this function. The rotation
106/// and centroid lookup happen in f32.
107fn turboquant_quantize_core(
108    fsl: &FixedSizeListArray,
109    seed: u64,
110    bit_width: u8,
111    num_rounds: u8,
112    ctx: &mut ExecutionCtx,
113) -> VortexResult<QuantizationResult> {
114    let dimension =
115        usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize");
116    let num_rows = fsl.len();
117
118    let rotation = RotationMatrix::try_new(seed, dimension, num_rounds as usize)?;
119    let padded_dim = rotation.padded_dim();
120    let padded_dim_u32 =
121        u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
122
123    let f32_elements = extract_f32_elements(fsl, ctx)?;
124
125    let centroids = get_centroids(padded_dim_u32, bit_width)?;
126    let boundaries = compute_centroid_boundaries(&centroids);
127
128    let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
129    let mut padded = vec![0.0f32; padded_dim];
130    let mut rotated = vec![0.0f32; padded_dim];
131
132    let f32_slice = f32_elements.as_slice::<f32>();
133    for row in 0..num_rows {
134        let x = &f32_slice[row * dimension..(row + 1) * dimension];
135
136        // Zero-pad to the next power of 2.
137        padded[..dimension].copy_from_slice(x);
138        padded[dimension..].fill(0.0);
139
140        rotation.rotate(&padded, &mut rotated);
141
142        for j in 0..padded_dim {
143            all_indices.push(find_nearest_centroid(rotated[j], &boundaries));
144        }
145    }
146
147    Ok(QuantizationResult {
148        rotation,
149        centroids,
150        all_indices,
151        padded_dim,
152    })
153}
154
155/// Build a `TurboQuantArray` from quantization results.
156///
157/// The `ext_dtype` must be a non-nullable [`Vector`](crate::vector::Vector) extension dtype.
158fn build_turboquant(
159    num_rows: usize,
160    core: QuantizationResult,
161    ext_dtype: &DType,
162) -> VortexResult<TurboQuantArray> {
163    let padded_dim = core.padded_dim;
164    let padded_dim_u32 =
165        u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
166    let codes_elements =
167        PrimitiveArray::new::<u8>(core.all_indices.freeze(), Validity::NonNullable);
168    let codes = FixedSizeListArray::try_new(
169        codes_elements.into_array(),
170        padded_dim_u32,
171        Validity::NonNullable,
172        num_rows,
173    )?
174    .into_array();
175
176    // TODO(perf): `get_centroids` returns Vec<f32>; could avoid the copy by
177    // supporting Buffer::from(Vec<T>) or caching as Buffer directly.
178    let mut centroids_buf = BufferMut::<f32>::with_capacity(core.centroids.len());
179    centroids_buf.extend_from_slice(&core.centroids);
180    let centroids_array =
181        PrimitiveArray::new::<f32>(centroids_buf.freeze(), Validity::NonNullable).into_array();
182
183    let rotation_signs = bitpack_rotation_signs(&core.rotation)?;
184
185    TurboQuant::try_new_array(ext_dtype.clone(), codes, centroids_array, rotation_signs)
186}
187
188/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a
189/// [`TurboQuantArray`].
190///
191/// The input must be a non-nullable Vector extension array whose rows are already unit-norm.
192/// **Null vectors are not supported.** The caller must normalize and strip nullability before
193/// calling this function, for example via [`normalize_as_l2_denorm`].
194///
195/// This function validates that every row is L2-normalized (or is exactly 0.0). Use
196/// [`turboquant_encode_unchecked`] to skip this check when the caller has just performed
197/// normalization.
198///
199/// The returned array is a plain [`TurboQuantArray`] that decompresses to unit-norm vectors.
200/// The caller is responsible for wrapping it in an [`L2Denorm`] ScalarFnArray if the original
201/// magnitudes need to be restored.
202///
203/// [`normalize_as_l2_denorm`]: crate::scalar_fns::l2_denorm::normalize_as_l2_denorm
204/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm
205pub fn turboquant_encode(
206    ext: ArrayView<Extension>,
207    config: &TurboQuantConfig,
208    ctx: &mut ExecutionCtx,
209) -> VortexResult<ArrayRef> {
210    let ext_dtype = ext.dtype().clone();
211
212    vortex_ensure!(
213        !ext_dtype.is_nullable(),
214        "TurboQuant input must be non-nullable (normalize first via L2Denorm), got {ext_dtype}",
215    );
216
217    validate_l2_normalized_rows(ext.as_ref().clone(), ctx)?;
218
219    // SAFETY: We just validated that the input is non-nullable and all rows are unit-norm.
220    unsafe { turboquant_encode_unchecked(ext, config, ctx) }
221}
222
223/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a
224/// [`TurboQuantArray`], without validating the unit-norm precondition.
225///
226/// # Safety
227///
228/// The caller must ensure:
229///
230/// - The input dtype is non-nullable.
231/// - Every row is L2-normalized (unit norm) or is a zero vector.
232///
233/// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently
234/// incorrect quantization results.
235pub unsafe fn turboquant_encode_unchecked(
236    ext: ArrayView<Extension>,
237    config: &TurboQuantConfig,
238    ctx: &mut ExecutionCtx,
239) -> VortexResult<ArrayRef> {
240    let ext_dtype = ext.dtype().clone();
241    let storage = ext.storage_array();
242    let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
243
244    vortex_ensure!(
245        config.bit_width >= 1 && config.bit_width <= TurboQuant::MAX_BIT_WIDTH,
246        "bit_width must be 1-{}, got {}",
247        TurboQuant::MAX_BIT_WIDTH,
248        config.bit_width
249    );
250    let dimension = fsl.list_size();
251    vortex_ensure!(
252        dimension >= TurboQuant::MIN_DIMENSION,
253        "TurboQuant requires dimension >= {}, got {dimension}",
254        TurboQuant::MIN_DIMENSION
255    );
256
257    if fsl.is_empty() {
258        let padded_dim = dimension.next_power_of_two();
259        let empty_codes = FixedSizeListArray::try_new(
260            PrimitiveArray::empty::<u8>(Nullability::NonNullable).into_array(),
261            padded_dim,
262            Validity::NonNullable,
263            0,
264        )?;
265
266        let empty_centroids = PrimitiveArray::empty::<f32>(Nullability::NonNullable);
267        let empty_signs = FixedSizeListArray::try_new(
268            PrimitiveArray::empty::<u8>(Nullability::NonNullable).into_array(),
269            padded_dim,
270            Validity::NonNullable,
271            0,
272        )?;
273
274        return Ok(TurboQuant::try_new_array(
275            ext_dtype,
276            empty_codes.into_array(),
277            empty_centroids.into_array(),
278            empty_signs.into_array(),
279        )?
280        .into_array());
281    }
282
283    let seed = config.seed.unwrap_or(42);
284    let num_rows = fsl.len();
285    let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
286
287    Ok(build_turboquant(num_rows, core, &ext_dtype)?.into_array())
288}
289
290/// Export rotation signs as a `FixedSizeListArray` wrapping a 1-bit [`BitPackedArray`].
291///
292/// The rotation matrix's `num_rounds * padded_dim` sign values are exported as 0/1 u8 values in
293/// inverse application order, bitpacked to 1 bit per sign, then wrapped in a
294/// `FixedSizeListArray` with `list_size = padded_dim` and `len = num_rounds`.
295fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult<ArrayRef> {
296    let signs_u8 = rotation.export_inverse_signs_u8();
297    let num_rounds = rotation.num_rounds();
298    let padded_dim = u32::try_from(rotation.padded_dim()).vortex_expect("padded_dim fits in u32");
299
300    let mut buf = BufferMut::<u8>::with_capacity(signs_u8.len());
301    buf.extend_from_slice(&signs_u8);
302    let prim = PrimitiveArray::new::<u8>(buf.freeze(), Validity::NonNullable);
303    let bitpacked = bitpack_encode(&prim, 1, None)?;
304
305    let fsl = FixedSizeListArray::try_new(
306        bitpacked.into_array(),
307        padded_dim,
308        Validity::NonNullable,
309        num_rounds,
310    )?;
311    Ok(fsl.into_array())
312}