vortex_tensor/encodings/turboquant/
compress.rs1use vortex_array::ArrayRef;
14use vortex_array::ArrayView;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::arrays::Extension;
18use vortex_array::arrays::ExtensionArray;
19use vortex_array::arrays::FixedSizeListArray;
20use vortex_array::arrays::PrimitiveArray;
21use vortex_array::arrays::dict::DictArray;
22use vortex_array::arrays::extension::ExtensionArrayExt;
23use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
24use vortex_array::dtype::Nullability;
25use vortex_array::dtype::extension::ExtDType;
26use vortex_array::extension::EmptyMetadata;
27use vortex_array::validity::Validity;
28use vortex_buffer::BufferMut;
29use vortex_error::VortexExpect;
30use vortex_error::VortexResult;
31use vortex_error::vortex_ensure;
32
33use crate::encodings::turboquant::MAX_BIT_WIDTH;
34use crate::encodings::turboquant::MIN_DIMENSION;
35use crate::encodings::turboquant::centroids::compute_centroid_boundaries;
36use crate::encodings::turboquant::centroids::find_nearest_centroid;
37use crate::encodings::turboquant::centroids::get_centroids;
38use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows;
39use crate::scalar_fns::sorf_transform::SorfMatrix;
40use crate::scalar_fns::sorf_transform::SorfOptions;
41use crate::scalar_fns::sorf_transform::SorfTransform;
42use crate::utils::cast_to_f32;
43use crate::vector::AnyVector;
44use crate::vector::Vector;
45
/// Tunable parameters for TurboQuant encoding.
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
    /// Bit width handed to the centroid lookup. The encoder rejects values outside
    /// `1..=MAX_BIT_WIDTH`.
    pub bit_width: u8,
    /// Seed for the SORF rotation. When `None`, the encoder falls back to `42`.
    pub seed: Option<u64>,
    /// Number of rounds passed to the SORF rotation and recorded in the `SorfOptions` of the
    /// encoded array.
    pub num_rounds: u8,
}
56
57impl Default for TurboQuantConfig {
58 fn default() -> Self {
59 Self {
60 bit_width: MAX_BIT_WIDTH,
61 seed: Some(42),
62 num_rounds: 3,
63 }
64 }
65}
66
67struct QuantizationResult {
69 centroids: Vec<f32>,
70 all_indices: BufferMut<u8>,
71 padded_dim: usize,
72}
73
74fn turboquant_quantize_core(
80 fsl: &FixedSizeListArray,
81 seed: u64,
82 bit_width: u8,
83 num_rounds: u8,
84 ctx: &mut ExecutionCtx,
85) -> VortexResult<QuantizationResult> {
86 let dimension =
87 usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize");
88 let num_rows = fsl.len();
89
90 let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?;
91 let padded_dim = rotation.padded_dim();
92 let padded_dim_u32 =
93 u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
94
95 let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
96 let f32_elements = cast_to_f32(elements_prim)?;
97
98 let centroids = get_centroids(padded_dim_u32, bit_width)?;
99 let boundaries = compute_centroid_boundaries(¢roids);
100
101 let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
102 let mut padded = vec![0.0f32; padded_dim];
103 let mut rotated = vec![0.0f32; padded_dim];
104
105 let f32_slice = f32_elements.as_slice();
106 for row in 0..num_rows {
107 let x = &f32_slice[row * dimension..(row + 1) * dimension];
108
109 padded[..dimension].copy_from_slice(x);
111 padded[dimension..].fill(0.0);
112
113 rotation.rotate(&padded, &mut rotated);
114
115 for j in 0..padded_dim {
116 all_indices.push(find_nearest_centroid(rotated[j], &boundaries));
117 }
118 }
119
120 Ok(QuantizationResult {
121 centroids,
122 all_indices,
123 padded_dim,
124 })
125}
126
127fn build_quantized_fsl(
133 num_rows: usize,
134 all_indices: BufferMut<u8>,
135 centroids: &[f32],
136 padded_dim: usize,
137) -> VortexResult<ArrayRef> {
138 let codes = PrimitiveArray::new::<u8>(all_indices.freeze(), Validity::NonNullable);
139
140 let mut centroids_buf = BufferMut::<f32>::with_capacity(centroids.len());
141 centroids_buf.extend_from_slice(centroids);
142 let centroids_array = PrimitiveArray::new::<f32>(centroids_buf.freeze(), Validity::NonNullable);
143
144 let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?;
145
146 let padded_dim_u32 =
147 u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
148 Ok(FixedSizeListArray::try_new(
149 dict.into_array(),
150 padded_dim_u32,
151 Validity::NonNullable,
152 num_rows,
153 )?
154 .into_array())
155}
156
157pub fn turboquant_encode(
175 ext: ArrayView<Extension>,
176 config: &TurboQuantConfig,
177 ctx: &mut ExecutionCtx,
178) -> VortexResult<ArrayRef> {
179 let ext_dtype = ext.dtype().clone();
180
181 vortex_ensure!(
182 !ext_dtype.is_nullable(),
183 "TurboQuant input must be non-nullable (normalize first via L2Denorm), got {ext_dtype}",
184 );
185
186 validate_l2_normalized_rows(ext.as_ref(), ctx)?;
187
188 unsafe { turboquant_encode_unchecked(ext, config, ctx) }
190}
191
192pub unsafe fn turboquant_encode_unchecked(
206 ext: ArrayView<Extension>,
207 config: &TurboQuantConfig,
208 ctx: &mut ExecutionCtx,
209) -> VortexResult<ArrayRef> {
210 let ext_dtype = ext.dtype().clone();
211 let storage = ext.storage_array();
212 let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
213
214 vortex_ensure!(
215 config.bit_width >= 1 && config.bit_width <= MAX_BIT_WIDTH,
216 "bit_width must be 1-{MAX_BIT_WIDTH}, got {}",
217 config.bit_width
218 );
219 let dimension = fsl.list_size();
220 vortex_ensure!(
221 dimension >= MIN_DIMENSION,
222 "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}",
223 );
224
225 let vector_metadata = ext_dtype.as_extension().metadata::<AnyVector>();
226 let element_ptype = vector_metadata.element_ptype();
227
228 let seed = config.seed.unwrap_or(42);
229 let num_rows = fsl.len();
230
231 if fsl.is_empty() {
232 let padded_dim = dimension.next_power_of_two();
233 let empty_codes = PrimitiveArray::empty::<u8>(Nullability::NonNullable);
234 let empty_centroids = PrimitiveArray::empty::<f32>(Nullability::NonNullable);
235 let empty_dict =
236 DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?;
237 let empty_fsl = FixedSizeListArray::try_new(
238 empty_dict.into_array(),
239 padded_dim,
240 Validity::NonNullable,
241 0,
242 )?;
243 let empty_padded_vector = wrap_padded_as_vector(empty_fsl.into_array())?;
244
245 let sorf_options = SorfOptions {
246 seed,
247 num_rounds: config.num_rounds,
248 dimension,
249 element_ptype,
250 };
251 return Ok(
252 SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(),
253 );
254 }
255
256 let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
257 let quantized_fsl =
258 build_quantized_fsl(num_rows, core.all_indices, &core.centroids, core.padded_dim)?;
259 let padded_vector = wrap_padded_as_vector(quantized_fsl)?;
260
261 let sorf_options = SorfOptions {
262 seed,
263 num_rounds: config.num_rounds,
264 dimension,
265 element_ptype,
266 };
267 Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array())
268}
269
270fn wrap_padded_as_vector(fsl: ArrayRef) -> VortexResult<ArrayRef> {
273 let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
274 Ok(ExtensionArray::new(ext_dtype, fsl).into_array())
275}