Skip to main content

vortex_tensor/encodings/turboquant/
compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! TurboQuant encoding (quantization) logic.
5//!
//! The input to [`turboquant_encode_unchecked`] must be a non-nullable
//! [`Vector`](crate::vector::Vector) extension array whose rows are already L2-normalized (unit
//! norm). Normalization is handled by
//! [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm), which
//! [`turboquant_encode`] applies itself and which the [`TurboQuantScheme`] calls before invoking
//! the unchecked function.
10//!
11//! [`TurboQuantScheme`]: crate::encodings::turboquant::TurboQuantScheme
12
13use vortex_array::ArrayRef;
14use vortex_array::ArrayView;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::arrays::Extension;
18use vortex_array::arrays::FixedSizeListArray;
19use vortex_array::arrays::PrimitiveArray;
20use vortex_array::arrays::dict::DictArray;
21use vortex_array::arrays::extension::ExtensionArrayExt;
22use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
23use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
24use vortex_array::dtype::Nullability;
25use vortex_array::validity::Validity;
26use vortex_buffer::Buffer;
27use vortex_buffer::BufferMut;
28use vortex_error::VortexExpect;
29use vortex_error::VortexResult;
30use vortex_error::vortex_ensure;
31
32use crate::encodings::turboquant::MAX_BIT_WIDTH;
33use crate::encodings::turboquant::MIN_DIMENSION;
34use crate::encodings::turboquant::centroids::compute_centroid_boundaries;
35use crate::encodings::turboquant::centroids::compute_or_get_centroids;
36use crate::encodings::turboquant::centroids::find_nearest_centroid;
37use crate::scalar_fns::l2_denorm::L2Denorm;
38use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
39use crate::scalar_fns::sorf_transform::SorfMatrix;
40use crate::scalar_fns::sorf_transform::SorfOptions;
41use crate::scalar_fns::sorf_transform::SorfTransform;
42use crate::types::vector::AnyVector;
43use crate::types::vector::Vector;
44use crate::utils::cast_to_f32;
45
/// Tunable parameters for a TurboQuant encoding pass.
///
/// The [`Default`] implementation selects the widest supported bit width, a fixed seed and
/// three rotation rounds.
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
    /// Number of bits used to encode each coordinate.
    ///
    /// The encoder accepts values from 1 through `MAX_BIT_WIDTH` and rejects anything else.
    pub bit_width: u8,
    /// Seed used to derive the rotation matrix.
    pub seed: u64,
    /// How many sign-diagonal + WHT rounds the structured rotation applies (default 3).
    pub num_rounds: u8,
}
56
57impl Default for TurboQuantConfig {
58    fn default() -> Self {
59        Self {
60            bit_width: MAX_BIT_WIDTH,
61            seed: 42,
62            num_rounds: 3,
63        }
64    }
65}
66
67/// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector)
68/// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized
69/// child via [`turboquant_encode_unchecked`], and reattach the stored norms as the outer
70/// [`L2Denorm`] wrapper.
71///
72/// The returned array has the canonical TurboQuant shape:
73///
74/// ```text
75/// ScalarFnArray(L2Denorm, [
76///     ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]),
77///     norms,
78/// ])
79/// ```
80///
81/// # Errors
82///
83/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or
84/// if [`turboquant_encode_unchecked`] rejects the input shape.
85pub fn turboquant_encode(
86    input: ArrayRef,
87    config: &TurboQuantConfig,
88    ctx: &mut ExecutionCtx,
89) -> VortexResult<ArrayRef> {
90    // We must normalize the array before we can encode it with TurboQuant.
91    let l2_denorm = normalize_as_l2_denorm(input, ctx)?;
92    let normalized = l2_denorm.child_at(0).clone();
93    let norms = l2_denorm.child_at(1).clone();
94    let num_rows = l2_denorm.len();
95
96    let normalized_ext = normalized
97        .as_opt::<Extension>()
98        .vortex_expect("normalize_as_l2_denorm always produces an Extension array child");
99
100    // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero for null rows).
101    let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?;
102
103    // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally
104    // bypass the strict normalized-row validation when reattaching the stored norms.
105    Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
106}
107
108/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a
109/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the unit-norm
110/// precondition.
111///
112/// # Safety
113///
114/// The caller must ensure:
115///
116/// - The input dtype is non-nullable.
117/// - Every row is L2-normalized (unit norm) or is a zero vector.
118///
119/// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently
120/// incorrect quantization results.
121pub unsafe fn turboquant_encode_unchecked(
122    ext: ArrayView<Extension>,
123    config: &TurboQuantConfig,
124    ctx: &mut ExecutionCtx,
125) -> VortexResult<ArrayRef> {
126    let ext_dtype = ext.dtype().clone();
127    let storage = ext.storage_array();
128    let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
129
130    vortex_ensure!(
131        config.bit_width >= 1 && config.bit_width <= MAX_BIT_WIDTH,
132        "bit_width must be 1-{MAX_BIT_WIDTH}, got {}",
133        config.bit_width
134    );
135    let dimension = fsl.list_size();
136    vortex_ensure!(
137        dimension >= MIN_DIMENSION,
138        "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}",
139    );
140
141    let vector_metadata = ext_dtype.as_extension().metadata::<AnyVector>();
142    let element_ptype = vector_metadata.element_ptype();
143
144    let seed = config.seed;
145    let num_rows = fsl.len();
146
147    if fsl.is_empty() {
148        let padded_dim = dimension.next_power_of_two();
149        let empty_codes = PrimitiveArray::empty::<u8>(Nullability::NonNullable);
150        let empty_centroids = PrimitiveArray::empty::<f32>(Nullability::NonNullable);
151        let empty_dict =
152            DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?;
153        let empty_fsl = FixedSizeListArray::try_new(
154            empty_dict.into_array(),
155            padded_dim,
156            Validity::NonNullable,
157            0,
158        )?;
159        let empty_padded_vector = Vector::try_new_vector_array(empty_fsl.into_array())?;
160
161        let sorf_options = SorfOptions {
162            seed,
163            num_rounds: config.num_rounds,
164            dimensions: dimension,
165            element_ptype,
166        };
167        return Ok(
168            SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(),
169        );
170    }
171
172    let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
173    let quantized_fsl =
174        build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?;
175    let padded_vector = Vector::try_new_vector_array(quantized_fsl)?;
176
177    let sorf_options = SorfOptions {
178        seed,
179        num_rounds: config.num_rounds,
180        dimensions: dimension,
181        element_ptype,
182    };
183    Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array())
184}
185
186/// Shared intermediate results from the quantization loop.
187struct QuantizationResult {
188    centroids: Buffer<f32>,
189    all_indices: Buffer<u8>,
190    padded_dim: usize,
191}
192
193/// Core quantization: rotate and quantize already-normalized rows.
194///
195/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null
196/// vectors are not supported and must be zeroed out before reaching this function. The rotation
197/// and centroid lookup happen in f32.
198fn turboquant_quantize_core(
199    fsl: &FixedSizeListArray,
200    seed: u64,
201    bit_width: u8,
202    num_rounds: u8,
203    ctx: &mut ExecutionCtx,
204) -> VortexResult<QuantizationResult> {
205    let dimension = fsl.list_size() as usize;
206    let num_rows = fsl.len();
207
208    let padded_dim = dimension.next_power_of_two();
209    let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?;
210    let padded_dim_u32 =
211        u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
212
213    let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
214    let f32_elements = cast_to_f32(elements_prim)?;
215
216    let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?;
217    let boundaries = compute_centroid_boundaries(&centroids);
218
219    let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
220    let mut padded = vec![0.0f32; padded_dim];
221    let mut rotated = vec![0.0f32; padded_dim];
222
223    let f32_slice = f32_elements.as_slice();
224    for row in 0..num_rows {
225        let x = &f32_slice[row * dimension..(row + 1) * dimension];
226
227        // Zero-pad to the next power of 2.
228        padded[..dimension].copy_from_slice(x);
229        padded[dimension..].fill(0.0);
230
231        rotation.rotate(&padded, &mut rotated);
232
233        for j in 0..padded_dim {
234            all_indices.push(find_nearest_centroid(rotated[j], &boundaries));
235        }
236    }
237
238    Ok(QuantizationResult {
239        centroids,
240        all_indices: all_indices.freeze(),
241        padded_dim,
242    })
243}
244
245/// Build a quantized representation: `FSL(DictArray(codes, centroids), padded_dim)`.
246///
247/// This is a Dict-encoded FixedSizeList where each row of `padded_dim` u8 codes indexes into the
248/// centroid codebook. The Dict can be independently sliced, taken, or executed (dequantized)
249/// without knowledge of the rotation.
250fn build_quantized_fsl(
251    num_rows: usize,
252    all_indices: Buffer<u8>,
253    centroids: Buffer<f32>,
254    padded_dim: usize,
255) -> VortexResult<ArrayRef> {
256    let codes = PrimitiveArray::new::<u8>(all_indices, Validity::NonNullable);
257    let centroids_array = PrimitiveArray::new::<f32>(centroids, Validity::NonNullable);
258
259    let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?;
260
261    let padded_dim_u32 =
262        u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
263    Ok(FixedSizeListArray::try_new(
264        dict.into_array(),
265        padded_dim_u32,
266        Validity::NonNullable,
267        num_rows,
268    )?
269    .into_array())
270}