Skip to main content

vortex_tensor/encodings/turboquant/
compress.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! TurboQuant encoding (quantization) logic.
5//!
6//! The input to [`turboquant_encode`] must be a non-nullable [`Vector`](crate::vector::Vector)
7//! extension array whose rows are already L2-normalized (unit norm). Normalization is handled
8//! externally by [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm),
9//! which the [`TurboQuantScheme`] calls before invoking this function.
10//!
11//! [`TurboQuantScheme`]: crate::encodings::turboquant::TurboQuantScheme
12
13use vortex_array::ArrayRef;
14use vortex_array::ArrayView;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::arrays::Extension;
18use vortex_array::arrays::ExtensionArray;
19use vortex_array::arrays::FixedSizeListArray;
20use vortex_array::arrays::PrimitiveArray;
21use vortex_array::arrays::dict::DictArray;
22use vortex_array::arrays::extension::ExtensionArrayExt;
23use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
24use vortex_array::dtype::Nullability;
25use vortex_array::dtype::extension::ExtDType;
26use vortex_array::extension::EmptyMetadata;
27use vortex_array::validity::Validity;
28use vortex_buffer::BufferMut;
29use vortex_error::VortexExpect;
30use vortex_error::VortexResult;
31use vortex_error::vortex_ensure;
32
33use crate::encodings::turboquant::MAX_BIT_WIDTH;
34use crate::encodings::turboquant::MIN_DIMENSION;
35use crate::encodings::turboquant::centroids::compute_centroid_boundaries;
36use crate::encodings::turboquant::centroids::find_nearest_centroid;
37use crate::encodings::turboquant::centroids::get_centroids;
38use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows;
39use crate::scalar_fns::sorf_transform::SorfMatrix;
40use crate::scalar_fns::sorf_transform::SorfOptions;
41use crate::scalar_fns::sorf_transform::SorfTransform;
42use crate::utils::cast_to_f32;
43use crate::vector::AnyVector;
44use crate::vector::Vector;
45
46/// Configuration for TurboQuant encoding.
47#[derive(Clone, Debug)]
48pub struct TurboQuantConfig {
49    /// Bits per coordinate (1-8).
50    pub bit_width: u8,
51    /// Optional seed for the rotation matrix. If None, the default seed is used.
52    pub seed: Option<u64>,
53    /// Number of sign-diagonal + WHT rounds in the structured rotation (default 3).
54    pub num_rounds: u8,
55}
56
57impl Default for TurboQuantConfig {
58    fn default() -> Self {
59        Self {
60            bit_width: MAX_BIT_WIDTH,
61            seed: Some(42),
62            num_rounds: 3,
63        }
64    }
65}
66
67/// Shared intermediate results from the quantization loop.
68struct QuantizationResult {
69    centroids: Vec<f32>,
70    all_indices: BufferMut<u8>,
71    padded_dim: usize,
72}
73
74/// Core quantization: rotate and quantize already-normalized rows.
75///
76/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null
77/// vectors are not supported and must be zeroed out before reaching this function. The rotation
78/// and centroid lookup happen in f32.
79fn turboquant_quantize_core(
80    fsl: &FixedSizeListArray,
81    seed: u64,
82    bit_width: u8,
83    num_rounds: u8,
84    ctx: &mut ExecutionCtx,
85) -> VortexResult<QuantizationResult> {
86    let dimension =
87        usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize");
88    let num_rows = fsl.len();
89
90    let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?;
91    let padded_dim = rotation.padded_dim();
92    let padded_dim_u32 =
93        u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
94
95    let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
96    let f32_elements = cast_to_f32(elements_prim)?;
97
98    let centroids = get_centroids(padded_dim_u32, bit_width)?;
99    let boundaries = compute_centroid_boundaries(&centroids);
100
101    let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
102    let mut padded = vec![0.0f32; padded_dim];
103    let mut rotated = vec![0.0f32; padded_dim];
104
105    let f32_slice = f32_elements.as_slice();
106    for row in 0..num_rows {
107        let x = &f32_slice[row * dimension..(row + 1) * dimension];
108
109        // Zero-pad to the next power of 2.
110        padded[..dimension].copy_from_slice(x);
111        padded[dimension..].fill(0.0);
112
113        rotation.rotate(&padded, &mut rotated);
114
115        for j in 0..padded_dim {
116            all_indices.push(find_nearest_centroid(rotated[j], &boundaries));
117        }
118    }
119
120    Ok(QuantizationResult {
121        centroids,
122        all_indices,
123        padded_dim,
124    })
125}
126
127/// Build a quantized representation: `FSL(DictArray(codes, centroids), padded_dim)`.
128///
129/// This is a Dict-encoded FixedSizeList where each row of `padded_dim` u8 codes
130/// indexes into the centroid codebook. The Dict can be independently sliced, taken,
131/// or executed (dequantized) without knowledge of the rotation.
132fn build_quantized_fsl(
133    num_rows: usize,
134    all_indices: BufferMut<u8>,
135    centroids: &[f32],
136    padded_dim: usize,
137) -> VortexResult<ArrayRef> {
138    let codes = PrimitiveArray::new::<u8>(all_indices.freeze(), Validity::NonNullable);
139
140    let mut centroids_buf = BufferMut::<f32>::with_capacity(centroids.len());
141    centroids_buf.extend_from_slice(centroids);
142    let centroids_array = PrimitiveArray::new::<f32>(centroids_buf.freeze(), Validity::NonNullable);
143
144    let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?;
145
146    let padded_dim_u32 =
147        u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
148    Ok(FixedSizeListArray::try_new(
149        dict.into_array(),
150        padded_dim_u32,
151        Validity::NonNullable,
152        num_rows,
153    )?
154    .into_array())
155}
156
157/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a
158/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`.
159///
160/// The input must be a non-nullable Vector extension array whose rows are already unit-norm.
161/// **Null vectors are not supported.** The caller must normalize and strip nullability before
162/// calling this function, for example via [`normalize_as_l2_denorm`].
163///
164/// This function validates that every row is L2-normalized (or is exactly 0.0). Use
165/// [`turboquant_encode_unchecked`] to skip this check when the caller has just performed
166/// normalization.
167///
168/// The returned array is a `SorfTransform` ScalarFnArray wrapping `FSL(Dict)` that decompresses
169/// to unit-norm vectors. The caller is responsible for wrapping it in an [`L2Denorm`] ScalarFnArray
170/// if the original magnitudes need to be restored.
171///
172/// [`normalize_as_l2_denorm`]: crate::scalar_fns::l2_denorm::normalize_as_l2_denorm
173/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm
174pub fn turboquant_encode(
175    ext: ArrayView<Extension>,
176    config: &TurboQuantConfig,
177    ctx: &mut ExecutionCtx,
178) -> VortexResult<ArrayRef> {
179    let ext_dtype = ext.dtype().clone();
180
181    vortex_ensure!(
182        !ext_dtype.is_nullable(),
183        "TurboQuant input must be non-nullable (normalize first via L2Denorm), got {ext_dtype}",
184    );
185
186    validate_l2_normalized_rows(ext.as_ref(), ctx)?;
187
188    // SAFETY: We just validated that the input is non-nullable and all rows are unit-norm.
189    unsafe { turboquant_encode_unchecked(ext, config, ctx) }
190}
191
192/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a
193/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the unit-norm
194/// precondition.
195///
196/// # Safety
197///
198/// The caller must ensure:
199///
200/// - The input dtype is non-nullable.
201/// - Every row is L2-normalized (unit norm) or is a zero vector.
202///
203/// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently
204/// incorrect quantization results.
205pub unsafe fn turboquant_encode_unchecked(
206    ext: ArrayView<Extension>,
207    config: &TurboQuantConfig,
208    ctx: &mut ExecutionCtx,
209) -> VortexResult<ArrayRef> {
210    let ext_dtype = ext.dtype().clone();
211    let storage = ext.storage_array();
212    let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
213
214    vortex_ensure!(
215        config.bit_width >= 1 && config.bit_width <= MAX_BIT_WIDTH,
216        "bit_width must be 1-{MAX_BIT_WIDTH}, got {}",
217        config.bit_width
218    );
219    let dimension = fsl.list_size();
220    vortex_ensure!(
221        dimension >= MIN_DIMENSION,
222        "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}",
223    );
224
225    let vector_metadata = ext_dtype.as_extension().metadata::<AnyVector>();
226    let element_ptype = vector_metadata.element_ptype();
227
228    let seed = config.seed.unwrap_or(42);
229    let num_rows = fsl.len();
230
231    if fsl.is_empty() {
232        let padded_dim = dimension.next_power_of_two();
233        let empty_codes = PrimitiveArray::empty::<u8>(Nullability::NonNullable);
234        let empty_centroids = PrimitiveArray::empty::<f32>(Nullability::NonNullable);
235        let empty_dict =
236            DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?;
237        let empty_fsl = FixedSizeListArray::try_new(
238            empty_dict.into_array(),
239            padded_dim,
240            Validity::NonNullable,
241            0,
242        )?;
243        let empty_padded_vector = wrap_padded_as_vector(empty_fsl.into_array())?;
244
245        let sorf_options = SorfOptions {
246            seed,
247            num_rounds: config.num_rounds,
248            dimension,
249            element_ptype,
250        };
251        return Ok(
252            SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(),
253        );
254    }
255
256    let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
257    let quantized_fsl =
258        build_quantized_fsl(num_rows, core.all_indices, &core.centroids, core.padded_dim)?;
259    let padded_vector = wrap_padded_as_vector(quantized_fsl)?;
260
261    let sorf_options = SorfOptions {
262        seed,
263        num_rounds: config.num_rounds,
264        dimension,
265        element_ptype,
266    };
267    Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array())
268}
269
270/// Wrap an `FSL<f32, padded_dim>` in a [`Vector`](crate::vector::Vector) extension so it can be
271/// passed as the child of [`SorfTransform`], which expects a `Vector<padded_dim>` input.
272fn wrap_padded_as_vector(fsl: ArrayRef) -> VortexResult<ArrayRef> {
273    let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
274    Ok(ExtensionArray::new(ext_dtype, fsl).into_array())
275}