Skip to main content

vortex_tensor/encodings/turboquant/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! VTable implementation for TurboQuant encoding.
5
6use std::hash::Hash;
7use std::hash::Hasher;
8use std::sync::Arc;
9
10use prost::Message;
11use vortex_array::Array;
12use vortex_array::ArrayEq;
13use vortex_array::ArrayHash;
14use vortex_array::ArrayId;
15use vortex_array::ArrayParts;
16use vortex_array::ArrayRef;
17use vortex_array::ArrayView;
18use vortex_array::ExecutionCtx;
19use vortex_array::ExecutionResult;
20use vortex_array::Precision;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::dtype::PType;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::validity::Validity;
27use vortex_array::vtable::VTable;
28use vortex_array::vtable::ValidityVTable;
29use vortex_error::VortexExpect;
30use vortex_error::VortexResult;
31use vortex_error::vortex_ensure;
32use vortex_error::vortex_ensure_eq;
33use vortex_error::vortex_err;
34use vortex_error::vortex_panic;
35use vortex_session::VortexSession;
36
37use crate::encodings::turboquant::TurboQuantData;
38use crate::encodings::turboquant::array::slots::Slot;
39use crate::encodings::turboquant::compute::rules::PARENT_KERNELS;
40use crate::encodings::turboquant::compute::rules::RULES;
41use crate::encodings::turboquant::metadata::TurboQuantMetadata;
42use crate::encodings::turboquant::scheme::decompress::execute_decompress;
43use crate::vector::AnyVector;
44use crate::vector::VectorMatcherMetadata;
45
/// Zero-sized marker type that identifies the TurboQuant vector encoding.
///
/// It carries no state; all per-array information lives in [`TurboQuantData`] and the child
/// slots of a [`TurboQuantArray`].
#[derive(Debug, Clone)]
pub struct TurboQuant;
49
50impl TurboQuant {
51    pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant");
52
53    /// Minimum vector dimension for TurboQuant encoding.
54    ///
55    /// Note that this is not a theoretical minimum, it is mostly a practical one to limit the total
56    /// amount of distortion.
57    pub const MIN_DIMENSION: u32 = 128;
58
59    /// Maximum supported number of bits per quantized coordinate.
60    pub const MAX_BIT_WIDTH: u8 = 8;
61
62    /// Maximum supported number of centroids in the scalar quantizer codebook.
63    pub const MAX_CENTROIDS: usize = 1usize << (Self::MAX_BIT_WIDTH as usize);
64
65    /// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with
66    /// dimension >= [`MIN_DIMENSION`](Self::MIN_DIMENSION).
67    ///
68    /// Returns the validated vector metadata on success.
69    pub fn validate_dtype(dtype: &DType) -> VortexResult<VectorMatcherMetadata> {
70        let vector_metadata = dtype
71            .as_extension_opt()
72            .and_then(|ext| ext.metadata_opt::<AnyVector>())
73            .ok_or_else(|| {
74                vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}")
75            })?;
76
77        let dimensions = vector_metadata.dimensions();
78        vortex_ensure!(
79            dimensions >= Self::MIN_DIMENSION,
80            "TurboQuant requires dimension >= {}, got {dimensions}",
81            Self::MIN_DIMENSION
82        );
83
84        Ok(vector_metadata)
85    }
86
87    /// Creates a new [`TurboQuantArray`].
88    ///
89    /// The `dtype` must be a non-nullable [`Vector`](crate::vector::Vector) extension type.
90    /// Nullability is handled externally by the [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)
91    /// ScalarFnArray wrapper.
92    ///
93    /// Internally calls [`TurboQuantData::validate`] and [`TurboQuantData::try_new`], then
94    /// delegates to [`new_array_unchecked`](Self::new_array_unchecked).
95    pub fn try_new_array(
96        dtype: DType,
97        codes: ArrayRef,
98        centroids: ArrayRef,
99        rotation_signs: ArrayRef,
100    ) -> VortexResult<TurboQuantArray> {
101        TurboQuantData::validate(&dtype, &codes, &centroids, &rotation_signs)?;
102
103        Ok(unsafe { Self::new_array_unchecked(dtype, codes, centroids, rotation_signs) })
104    }
105
106    /// Creates a new [`TurboQuantArray`] without validation.
107    ///
108    /// # Safety
109    ///
110    /// The caller must ensure all invariants required by [`TurboQuantData::validate`] hold:
111    ///
112    /// - `dtype` is a non-nullable [`Vector`](crate::vector::Vector) extension type with
113    ///   dimension >= [`MIN_DIMENSION`](Self::MIN_DIMENSION).
114    /// - `codes` is a non-nullable `FixedSizeList<u8>` with `list_size == padded_dim`.
115    /// - `centroids` is a non-nullable `Primitive<f32>` with a power-of-2 length in
116    ///   `[2, MAX_CENTROIDS]` (or empty for degenerate arrays).
117    /// - `rotation_signs` is a non-nullable `FixedSizeList<u8>` with `list_size == padded_dim`.
118    ///
119    /// Violating these invariants may produce incorrect results during decompression or panics
120    /// during array access.
121    pub unsafe fn new_array_unchecked(
122        dtype: DType,
123        codes: ArrayRef,
124        centroids: ArrayRef,
125        rotation_signs: ArrayRef,
126    ) -> TurboQuantArray {
127        #[cfg(debug_assertions)]
128        TurboQuantData::validate(&dtype, &codes, &centroids, &rotation_signs)
129            .vortex_expect("[DEBUG ASSERTION]: TurboQuantData arrays are invalid");
130
131        let len = codes.len();
132
133        let dimension = dtype
134            .as_extension_opt()
135            .vortex_expect("we validated the dtype")
136            .metadata_opt::<AnyVector>()
137            .vortex_expect("we validated that this is a vector")
138            .dimensions();
139
140        let bit_width = if centroids.is_empty() {
141            0
142        } else {
143            #[expect(
144                clippy::cast_possible_truncation,
145                reason = "bit_width is guaranteed <= 8"
146            )]
147            (centroids.len().trailing_zeros() as u8)
148        };
149
150        #[expect(
151            clippy::cast_possible_truncation,
152            reason = "num_rounds fits in u8 by the caller's invariants"
153        )]
154        let num_rounds = rotation_signs.len() as u8;
155
156        // SAFETY: The caller guarantees that dimension, bit_width, and num_rounds satisfy the
157        // invariants documented on `TurboQuantData::new_unchecked`.
158        let data = unsafe { TurboQuantData::new_unchecked(dimension, bit_width, num_rounds) };
159        let parts = ArrayParts::new(TurboQuant, dtype, len, data)
160            .with_slots(TurboQuantData::make_slots(codes, centroids, rotation_signs));
161
162        // SAFETY: The caller guarantees the parts are logically consistent.
163        unsafe { Array::from_parts_unchecked(parts) }
164    }
165}
166
167/// A [`TurboQuant`]-encoded Vortex array.
168pub type TurboQuantArray = Array<TurboQuant>;
169
170impl VTable for TurboQuant {
171    type ArrayData = TurboQuantData;
172    type OperationsVTable = TurboQuant;
173    type ValidityVTable = TurboQuant;
174
175    fn id(&self) -> ArrayId {
176        Self::ID
177    }
178
179    fn validate(
180        &self,
181        data: &Self::ArrayData,
182        dtype: &DType,
183        len: usize,
184        slots: &[Option<ArrayRef>],
185    ) -> VortexResult<()> {
186        vortex_ensure_eq!(
187            slots.len(),
188            Slot::COUNT,
189            "TurboQuantArray got incorrect amount of slots",
190        );
191
192        // Even if the array is degenerate (empty), the arrays still have to exist
193        // (they will be empty).
194        let codes = slots[Slot::Codes as usize]
195            .as_ref()
196            .ok_or_else(|| vortex_err!("TurboQuantArray missing codes slot"))?;
197        let centroids = slots[Slot::Centroids as usize]
198            .as_ref()
199            .ok_or_else(|| vortex_err!("TurboQuantArray missing centroids slot"))?;
200        let rotation_signs = slots[Slot::RotationSigns as usize]
201            .as_ref()
202            .ok_or_else(|| vortex_err!("TurboQuantArray missing rotation_signs slot"))?;
203
204        vortex_ensure_eq!(
205            codes.len(),
206            len,
207            "TurboQuant codes length does not match outer length",
208        );
209
210        TurboQuantData::validate(dtype, codes, centroids, rotation_signs)?;
211
212        vortex_ensure_eq!(data.dimension, Self::validate_dtype(dtype)?.dimensions());
213
214        let expected_bit_width = if centroids.is_empty() {
215            0
216        } else {
217            u8::try_from(centroids.len().trailing_zeros())
218                .map_err(|_| vortex_err!("centroids bit_width does not fit in u8"))?
219        };
220        vortex_ensure_eq!(
221            data.bit_width,
222            expected_bit_width,
223            "TurboQuant bit_width does not match centroids slot",
224        );
225
226        // Verify num_rounds matches the rotation_signs FSL length.
227        let expected_num_rounds = u8::try_from(rotation_signs.len())
228            .map_err(|_| vortex_err!("rotation_signs num_rounds does not fit in u8"))?;
229        vortex_ensure_eq!(
230            data.num_rounds,
231            expected_num_rounds,
232            "TurboQuant num_rounds does not match rotation_signs slot",
233        );
234
235        Ok(())
236    }
237
238    fn nbuffers(_array: ArrayView<Self>) -> usize {
239        0
240    }
241
242    fn buffer(_array: ArrayView<Self>, idx: usize) -> BufferHandle {
243        vortex_panic!("TurboQuantArray buffer index {idx} out of bounds")
244    }
245
246    fn buffer_name(_array: ArrayView<Self>, _idx: usize) -> Option<String> {
247        None
248    }
249
250    fn serialize(
251        array: ArrayView<'_, Self>,
252        _session: &VortexSession,
253    ) -> VortexResult<Option<Vec<u8>>> {
254        Ok(Some(
255            TurboQuantMetadata::new(array.bit_width, array.num_rounds).encode_to_vec(),
256        ))
257    }
258
259    fn deserialize(
260        &self,
261        dtype: &DType,
262        len: usize,
263        metadata: &[u8],
264        _buffers: &[BufferHandle],
265        children: &dyn ArrayChildren,
266        _session: &VortexSession,
267    ) -> VortexResult<ArrayParts<Self>> {
268        let metadata = TurboQuantMetadata::decode(metadata)?;
269        let bit_width = metadata.bit_width()?;
270        let num_rounds = metadata.num_rounds()?;
271
272        // bit_width == 0 and num_rounds == 0 are only valid for degenerate (empty) arrays.
273        vortex_ensure!(
274            bit_width > 0 || len == 0,
275            "bit_width == 0 is only valid for empty arrays, got len={len}"
276        );
277        vortex_ensure!(
278            num_rounds > 0 || len == 0,
279            "num_rounds == 0 is only valid for empty arrays, got len={len}"
280        );
281
282        // Validate and derive dimension from the Vector extension dtype.
283        let vector_metadata = TurboQuant::validate_dtype(dtype)?;
284        let dimensions = vector_metadata.dimensions();
285
286        // TurboQuant arrays are always non-nullable.
287        vortex_ensure!(
288            !dtype.is_nullable(),
289            "TurboQuant dtype must be non-nullable during deserialization"
290        );
291
292        let padded_dim = dimensions.next_power_of_two();
293
294        // Get the codes array (indices into the codebook). Codes are always non-nullable.
295        let codes_ptype = DType::Primitive(PType::U8, Nullability::NonNullable);
296        let codes_dtype =
297            DType::FixedSizeList(Arc::new(codes_ptype), padded_dim, Nullability::NonNullable);
298        let codes_array = children.get(0, &codes_dtype, len)?;
299
300        // Get the centroids array (codebook).
301        let num_centroids = if bit_width == 0 {
302            0 // A degenerate TQ array.
303        } else {
304            1usize << bit_width
305        };
306        let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
307        let centroids = children.get(1, &centroids_dtype, num_centroids)?;
308
309        // Get the rotation signs array (FixedSizeList<u8> with list_size = padded_dim).
310        let signs_len = if len == 0 { 0 } else { num_rounds as usize };
311        let signs_dtype = DType::FixedSizeList(
312            Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
313            padded_dim,
314            Nullability::NonNullable,
315        );
316        let rotation_signs = children.get(2, &signs_dtype, signs_len)?;
317
318        Ok(ArrayParts::new(
319            TurboQuant,
320            dtype.clone(),
321            len,
322            TurboQuantData {
323                dimension: dimensions,
324                bit_width,
325                num_rounds,
326            },
327        )
328        .with_slots(TurboQuantData::make_slots(
329            codes_array,
330            centroids,
331            rotation_signs,
332        )))
333    }
334
335    fn slot_name(_array: ArrayView<Self>, idx: usize) -> String {
336        Slot::from_index(idx).name().to_string()
337    }
338    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
339        Ok(ExecutionResult::done(execute_decompress(array, ctx)?))
340    }
341
342    fn execute_parent(
343        array: ArrayView<Self>,
344        parent: &ArrayRef,
345        child_idx: usize,
346        ctx: &mut ExecutionCtx,
347    ) -> VortexResult<Option<ArrayRef>> {
348        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
349    }
350
351    fn reduce_parent(
352        array: ArrayView<Self>,
353        parent: &ArrayRef,
354        child_idx: usize,
355    ) -> VortexResult<Option<ArrayRef>> {
356        RULES.evaluate(array, parent, child_idx)
357    }
358}
359
360impl ValidityVTable<TurboQuant> for TurboQuant {
361    fn validity(_array: ArrayView<'_, TurboQuant>) -> VortexResult<Validity> {
362        // TurboQuant arrays are always non-nullable. This method is only called when the dtype is
363        // nullable, which should never happen for TQ arrays.
364        Ok(Validity::NonNullable)
365    }
366}
367
368impl ArrayHash for TurboQuantData {
369    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
370        self.dimension.hash(state);
371        self.bit_width.hash(state);
372        self.num_rounds.hash(state);
373    }
374}
375
376impl ArrayEq for TurboQuantData {
377    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
378        self.dimension == other.dimension
379            && self.bit_width == other.bit_width
380            && self.num_rounds == other.num_rounds
381    }
382}