vortex_tensor/encodings/turboquant/
vtable.rs1use std::hash::Hash;
7use std::hash::Hasher;
8use std::sync::Arc;
9
10use prost::Message;
11use vortex_array::Array;
12use vortex_array::ArrayEq;
13use vortex_array::ArrayHash;
14use vortex_array::ArrayId;
15use vortex_array::ArrayParts;
16use vortex_array::ArrayRef;
17use vortex_array::ArrayView;
18use vortex_array::ExecutionCtx;
19use vortex_array::ExecutionResult;
20use vortex_array::Precision;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::dtype::PType;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::validity::Validity;
27use vortex_array::vtable::VTable;
28use vortex_array::vtable::ValidityVTable;
29use vortex_error::VortexExpect;
30use vortex_error::VortexResult;
31use vortex_error::vortex_ensure;
32use vortex_error::vortex_ensure_eq;
33use vortex_error::vortex_err;
34use vortex_error::vortex_panic;
35use vortex_session::VortexSession;
36
37use crate::encodings::turboquant::TurboQuantData;
38use crate::encodings::turboquant::array::slots::Slot;
39use crate::encodings::turboquant::compute::rules::PARENT_KERNELS;
40use crate::encodings::turboquant::compute::rules::RULES;
41use crate::encodings::turboquant::metadata::TurboQuantMetadata;
42use crate::encodings::turboquant::scheme::decompress::execute_decompress;
43use crate::vector::AnyVector;
44use crate::vector::VectorMatcherMetadata;
45
/// Encoding marker for TurboQuant-compressed vector arrays.
///
/// This is a zero-sized type with no state of its own; it is the implementor of the
/// encoding's `VTable`, `OperationsVTable` and `ValidityVTable`.
#[derive(Debug, Clone)]
pub struct TurboQuant;

50impl TurboQuant {
51 pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant");
52
53 pub const MIN_DIMENSION: u32 = 128;
58
59 pub const MAX_BIT_WIDTH: u8 = 8;
61
62 pub const MAX_CENTROIDS: usize = 1usize << (Self::MAX_BIT_WIDTH as usize);
64
65 pub fn validate_dtype(dtype: &DType) -> VortexResult<VectorMatcherMetadata> {
70 let vector_metadata = dtype
71 .as_extension_opt()
72 .and_then(|ext| ext.metadata_opt::<AnyVector>())
73 .ok_or_else(|| {
74 vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}")
75 })?;
76
77 let dimensions = vector_metadata.dimensions();
78 vortex_ensure!(
79 dimensions >= Self::MIN_DIMENSION,
80 "TurboQuant requires dimension >= {}, got {dimensions}",
81 Self::MIN_DIMENSION
82 );
83
84 Ok(vector_metadata)
85 }
86
87 pub fn try_new_array(
96 dtype: DType,
97 codes: ArrayRef,
98 centroids: ArrayRef,
99 rotation_signs: ArrayRef,
100 ) -> VortexResult<TurboQuantArray> {
101 TurboQuantData::validate(&dtype, &codes, ¢roids, &rotation_signs)?;
102
103 Ok(unsafe { Self::new_array_unchecked(dtype, codes, centroids, rotation_signs) })
104 }
105
106 pub unsafe fn new_array_unchecked(
122 dtype: DType,
123 codes: ArrayRef,
124 centroids: ArrayRef,
125 rotation_signs: ArrayRef,
126 ) -> TurboQuantArray {
127 #[cfg(debug_assertions)]
128 TurboQuantData::validate(&dtype, &codes, ¢roids, &rotation_signs)
129 .vortex_expect("[DEBUG ASSERTION]: TurboQuantData arrays are invalid");
130
131 let len = codes.len();
132
133 let dimension = dtype
134 .as_extension_opt()
135 .vortex_expect("we validated the dtype")
136 .metadata_opt::<AnyVector>()
137 .vortex_expect("we validated that this is a vector")
138 .dimensions();
139
140 let bit_width = if centroids.is_empty() {
141 0
142 } else {
143 #[expect(
144 clippy::cast_possible_truncation,
145 reason = "bit_width is guaranteed <= 8"
146 )]
147 (centroids.len().trailing_zeros() as u8)
148 };
149
150 #[expect(
151 clippy::cast_possible_truncation,
152 reason = "num_rounds fits in u8 by the caller's invariants"
153 )]
154 let num_rounds = rotation_signs.len() as u8;
155
156 let data = unsafe { TurboQuantData::new_unchecked(dimension, bit_width, num_rounds) };
159 let parts = ArrayParts::new(TurboQuant, dtype, len, data)
160 .with_slots(TurboQuantData::make_slots(codes, centroids, rotation_signs));
161
162 unsafe { Array::from_parts_unchecked(parts) }
164 }
165}
166
167pub type TurboQuantArray = Array<TurboQuant>;
169
170impl VTable for TurboQuant {
171 type ArrayData = TurboQuantData;
172 type OperationsVTable = TurboQuant;
173 type ValidityVTable = TurboQuant;
174
175 fn id(&self) -> ArrayId {
176 Self::ID
177 }
178
179 fn validate(
180 &self,
181 data: &Self::ArrayData,
182 dtype: &DType,
183 len: usize,
184 slots: &[Option<ArrayRef>],
185 ) -> VortexResult<()> {
186 vortex_ensure_eq!(
187 slots.len(),
188 Slot::COUNT,
189 "TurboQuantArray got incorrect amount of slots",
190 );
191
192 let codes = slots[Slot::Codes as usize]
195 .as_ref()
196 .ok_or_else(|| vortex_err!("TurboQuantArray missing codes slot"))?;
197 let centroids = slots[Slot::Centroids as usize]
198 .as_ref()
199 .ok_or_else(|| vortex_err!("TurboQuantArray missing centroids slot"))?;
200 let rotation_signs = slots[Slot::RotationSigns as usize]
201 .as_ref()
202 .ok_or_else(|| vortex_err!("TurboQuantArray missing rotation_signs slot"))?;
203
204 vortex_ensure_eq!(
205 codes.len(),
206 len,
207 "TurboQuant codes length does not match outer length",
208 );
209
210 TurboQuantData::validate(dtype, codes, centroids, rotation_signs)?;
211
212 vortex_ensure_eq!(data.dimension, Self::validate_dtype(dtype)?.dimensions());
213
214 let expected_bit_width = if centroids.is_empty() {
215 0
216 } else {
217 u8::try_from(centroids.len().trailing_zeros())
218 .map_err(|_| vortex_err!("centroids bit_width does not fit in u8"))?
219 };
220 vortex_ensure_eq!(
221 data.bit_width,
222 expected_bit_width,
223 "TurboQuant bit_width does not match centroids slot",
224 );
225
226 let expected_num_rounds = u8::try_from(rotation_signs.len())
228 .map_err(|_| vortex_err!("rotation_signs num_rounds does not fit in u8"))?;
229 vortex_ensure_eq!(
230 data.num_rounds,
231 expected_num_rounds,
232 "TurboQuant num_rounds does not match rotation_signs slot",
233 );
234
235 Ok(())
236 }
237
238 fn nbuffers(_array: ArrayView<Self>) -> usize {
239 0
240 }
241
242 fn buffer(_array: ArrayView<Self>, idx: usize) -> BufferHandle {
243 vortex_panic!("TurboQuantArray buffer index {idx} out of bounds")
244 }
245
246 fn buffer_name(_array: ArrayView<Self>, _idx: usize) -> Option<String> {
247 None
248 }
249
250 fn serialize(
251 array: ArrayView<'_, Self>,
252 _session: &VortexSession,
253 ) -> VortexResult<Option<Vec<u8>>> {
254 Ok(Some(
255 TurboQuantMetadata::new(array.bit_width, array.num_rounds).encode_to_vec(),
256 ))
257 }
258
259 fn deserialize(
260 &self,
261 dtype: &DType,
262 len: usize,
263 metadata: &[u8],
264 _buffers: &[BufferHandle],
265 children: &dyn ArrayChildren,
266 _session: &VortexSession,
267 ) -> VortexResult<ArrayParts<Self>> {
268 let metadata = TurboQuantMetadata::decode(metadata)?;
269 let bit_width = metadata.bit_width()?;
270 let num_rounds = metadata.num_rounds()?;
271
272 vortex_ensure!(
274 bit_width > 0 || len == 0,
275 "bit_width == 0 is only valid for empty arrays, got len={len}"
276 );
277 vortex_ensure!(
278 num_rounds > 0 || len == 0,
279 "num_rounds == 0 is only valid for empty arrays, got len={len}"
280 );
281
282 let vector_metadata = TurboQuant::validate_dtype(dtype)?;
284 let dimensions = vector_metadata.dimensions();
285
286 vortex_ensure!(
288 !dtype.is_nullable(),
289 "TurboQuant dtype must be non-nullable during deserialization"
290 );
291
292 let padded_dim = dimensions.next_power_of_two();
293
294 let codes_ptype = DType::Primitive(PType::U8, Nullability::NonNullable);
296 let codes_dtype =
297 DType::FixedSizeList(Arc::new(codes_ptype), padded_dim, Nullability::NonNullable);
298 let codes_array = children.get(0, &codes_dtype, len)?;
299
300 let num_centroids = if bit_width == 0 {
302 0 } else {
304 1usize << bit_width
305 };
306 let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
307 let centroids = children.get(1, ¢roids_dtype, num_centroids)?;
308
309 let signs_len = if len == 0 { 0 } else { num_rounds as usize };
311 let signs_dtype = DType::FixedSizeList(
312 Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
313 padded_dim,
314 Nullability::NonNullable,
315 );
316 let rotation_signs = children.get(2, &signs_dtype, signs_len)?;
317
318 Ok(ArrayParts::new(
319 TurboQuant,
320 dtype.clone(),
321 len,
322 TurboQuantData {
323 dimension: dimensions,
324 bit_width,
325 num_rounds,
326 },
327 )
328 .with_slots(TurboQuantData::make_slots(
329 codes_array,
330 centroids,
331 rotation_signs,
332 )))
333 }
334
335 fn slot_name(_array: ArrayView<Self>, idx: usize) -> String {
336 Slot::from_index(idx).name().to_string()
337 }
338 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
339 Ok(ExecutionResult::done(execute_decompress(array, ctx)?))
340 }
341
342 fn execute_parent(
343 array: ArrayView<Self>,
344 parent: &ArrayRef,
345 child_idx: usize,
346 ctx: &mut ExecutionCtx,
347 ) -> VortexResult<Option<ArrayRef>> {
348 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
349 }
350
351 fn reduce_parent(
352 array: ArrayView<Self>,
353 parent: &ArrayRef,
354 child_idx: usize,
355 ) -> VortexResult<Option<ArrayRef>> {
356 RULES.evaluate(array, parent, child_idx)
357 }
358}
359
360impl ValidityVTable<TurboQuant> for TurboQuant {
361 fn validity(_array: ArrayView<'_, TurboQuant>) -> VortexResult<Validity> {
362 Ok(Validity::NonNullable)
365 }
366}
367
368impl ArrayHash for TurboQuantData {
369 fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
370 self.dimension.hash(state);
371 self.bit_width.hash(state);
372 self.num_rounds.hash(state);
373 }
374}
375
376impl ArrayEq for TurboQuantData {
377 fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
378 self.dimension == other.dimension
379 && self.bit_width == other.bit_width
380 && self.num_rounds == other.num_rounds
381 }
382}