Skip to main content

vortex_tensor/encodings/turboquant/array/
data.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::sync::Arc;
7
8use vortex_array::ArrayRef;
9use vortex_array::TypedArrayRef;
10use vortex_array::dtype::DType;
11use vortex_array::dtype::Nullability;
12use vortex_array::dtype::PType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_ensure_eq;
17
18use crate::encodings::turboquant::array::slots::Slot;
19use crate::encodings::turboquant::vtable::TurboQuant;
20
/// Scalar metadata for a TurboQuant-encoded array.
///
/// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector)
/// extension arrays. The array's children hold quantized coordinate codes for unit-norm vectors,
/// a codebook of centroids shared by all rows, and the parameters of the current structured
/// rotation. This struct only carries the scalar parameters that describe those children.
///
/// Vector norms are not kept here. They belong in the external
/// [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) `ScalarFnArray` wrapper.
///
/// See the [module docs](crate::encodings::turboquant) for algorithmic details.
///
/// A degenerate TurboQuant array has zero rows, `bit_width == 0`, and every slot empty.
#[derive(Debug, Clone)]
pub struct TurboQuantData {
    /// Vector dimension `d`, cached from the list size of the `FixedSizeList` storage dtype.
    ///
    /// Kept as a field so callers need not re-extract it from the `dtype` every time.
    pub(crate) dimension: u32,

    /// Bits per coordinate (0-8), i.e. `log2(centroids.len())`.
    ///
    /// Degenerate empty arrays use 0.
    pub(crate) bit_width: u8,

    /// How many sign-diagonal + WHT rounds make up the structured rotation.
    ///
    /// Degenerate empty arrays use 0.
    pub(crate) num_rounds: u8,
}
51
52impl Display for TurboQuantData {
53    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54        write!(
55            f,
56            "dimension: {}, bit_width: {}, num_rounds: {}",
57            self.dimension, self.bit_width, self.num_rounds
58        )
59    }
60}
61
62impl TurboQuantData {
63    /// Build a `TurboQuantData` with validation.
64    ///
65    /// # Errors
66    ///
67    /// Returns an error if:
68    /// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
69    /// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH).
70    pub fn try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> VortexResult<Self> {
71        vortex_ensure!(
72            dimension >= TurboQuant::MIN_DIMENSION,
73            "TurboQuant requires dimension >= {}, got {dimension}",
74            TurboQuant::MIN_DIMENSION
75        );
76        vortex_ensure!(
77            bit_width <= TurboQuant::MAX_BIT_WIDTH,
78            "bit_width is expected to be between 0 and {}, got {bit_width}",
79            TurboQuant::MAX_BIT_WIDTH
80        );
81
82        Ok(Self {
83            dimension,
84            bit_width,
85            num_rounds,
86        })
87    }
88
89    /// Build a `TurboQuantData` without validation.
90    ///
91    /// # Safety
92    ///
93    /// The caller must ensure:
94    ///
95    /// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
96    /// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`.
97    /// - `num_rounds` is >= 1 (or 0 for degenerate empty arrays).
98    ///
99    /// Violating these invariants may produce incorrect results during decompression.
100    pub unsafe fn new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self {
101        Self {
102            dimension,
103            bit_width,
104            num_rounds,
105        }
106    }
107
108    /// Validates the components that would be used to create a `TurboQuantData`.
109    ///
110    /// This function checks all the invariants required by [`new_unchecked`](Self::new_unchecked).
111    pub fn validate(
112        dtype: &DType,
113        codes: &ArrayRef,
114        centroids: &ArrayRef,
115        rotation_signs: &ArrayRef,
116    ) -> VortexResult<()> {
117        let vector_metadata = TurboQuant::validate_dtype(dtype)?;
118        let dimension = vector_metadata.dimensions();
119        let padded_dim = dimension.next_power_of_two();
120
121        // TurboQuant arrays are always non-nullable. Nullability should be handled by the external
122        // L2Denorm ScalarFnArray wrapper.
123        vortex_ensure!(
124            !dtype.is_nullable(),
125            "TurboQuant dtype must be non-nullable, got {dtype}",
126        );
127
128        // Codes must be a non-nullable FixedSizeList<u8> with list_size == padded_dim.
129        let expected_codes_dtype = DType::FixedSizeList(
130            Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
131            padded_dim,
132            Nullability::NonNullable,
133        );
134        vortex_ensure_eq!(
135            *codes.dtype(),
136            expected_codes_dtype,
137            "codes dtype does not match expected {expected_codes_dtype}",
138        );
139
140        // Centroids are always f32 regardless of element type.
141        let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
142        vortex_ensure_eq!(
143            *centroids.dtype(),
144            centroids_dtype,
145            "centroids dtype must be non-nullable f32",
146        );
147
148        // Rotation signs must be a FixedSizeList<u8> with list_size == padded_dim. The FSL length
149        // is the number of rotation rounds.
150        let expected_signs_dtype = DType::FixedSizeList(
151            Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
152            padded_dim,
153            Nullability::NonNullable,
154        );
155        vortex_ensure_eq!(
156            *rotation_signs.dtype(),
157            expected_signs_dtype,
158            "rotation_signs dtype does not match expected {expected_signs_dtype}",
159        );
160        // Degenerate (empty) case: all children must be empty, and bit_width is 0.
161        let num_rows = codes.len();
162        if num_rows == 0 {
163            vortex_ensure!(
164                centroids.is_empty(),
165                "degenerate TurboQuant must have empty centroids, got length {}",
166                centroids.len()
167            );
168            vortex_ensure!(
169                rotation_signs.is_empty(),
170                "degenerate TurboQuant must have empty rotation_signs, got length {}",
171                rotation_signs.len()
172            );
173            return Ok(());
174        }
175
176        vortex_ensure!(
177            !rotation_signs.is_empty(),
178            "rotation_signs must have at least 1 round"
179        );
180
181        // Non-degenerate: derive and validate bit_width from centroids.
182        let num_centroids = centroids.len();
183        vortex_ensure!(
184            num_centroids.is_power_of_two()
185                && (2..=TurboQuant::MAX_CENTROIDS).contains(&num_centroids),
186            "centroids length must be a power of 2 in [2, {}], got {num_centroids}",
187            TurboQuant::MAX_CENTROIDS
188        );
189
190        #[expect(
191            clippy::cast_possible_truncation,
192            reason = "Guaranteed to be [1,8] by the preceding power-of-2 and range checks."
193        )]
194        let bit_width = num_centroids.trailing_zeros() as u8;
195        vortex_ensure!(
196            (1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width),
197            "derived bit_width must be 1-{}, got {bit_width}",
198            TurboQuant::MAX_BIT_WIDTH
199        );
200
201        Ok(())
202    }
203
204    pub(crate) fn make_slots(
205        codes: ArrayRef,
206        centroids: ArrayRef,
207        rotation_signs: ArrayRef,
208    ) -> Vec<Option<ArrayRef>> {
209        let mut slots = vec![None; Slot::COUNT];
210        slots[Slot::Codes as usize] = Some(codes);
211        slots[Slot::Centroids as usize] = Some(centroids);
212        slots[Slot::RotationSigns as usize] = Some(rotation_signs);
213        slots
214    }
215
216    /// The vector dimension `d`, as stored in the [`Vector`](crate::vector::Vector) extension
217    /// dtype's `FixedSizeList` storage.
218    pub fn dimension(&self) -> u32 {
219        self.dimension
220    }
221
222    /// MSE bits per coordinate (1-MAX_BIT_WIDTH for non-empty arrays, 0 for degenerate empty arrays).
223    pub fn bit_width(&self) -> u8 {
224        self.bit_width
225    }
226
227    /// The number of sign-diagonal + WHT rounds in the structured rotation.
228    pub fn num_rounds(&self) -> u8 {
229        self.num_rounds
230    }
231
232    /// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)).
233    ///
234    /// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so
235    /// non-power-of-2 dimensions are zero-padded to this value.
236    pub fn padded_dim(&self) -> u32 {
237        self.dimension.next_power_of_two()
238    }
239}
240
241pub trait TurboQuantArrayExt: TypedArrayRef<TurboQuant> {
242    fn codes(&self) -> &ArrayRef {
243        self.as_ref().slots()[Slot::Codes as usize]
244            .as_ref()
245            .vortex_expect("TurboQuantArray codes slot")
246    }
247
248    fn centroids(&self) -> &ArrayRef {
249        self.as_ref().slots()[Slot::Centroids as usize]
250            .as_ref()
251            .vortex_expect("TurboQuantArray centroids slot")
252    }
253
254    fn rotation_signs(&self) -> &ArrayRef {
255        self.as_ref().slots()[Slot::RotationSigns as usize]
256            .as_ref()
257            .vortex_expect("TurboQuantArray rotation_signs slot")
258    }
259}
260
261impl<T: TypedArrayRef<TurboQuant>> TurboQuantArrayExt for T {}