vortex_tensor/encodings/turboquant/
compress.rs1use vortex_array::ArrayRef;
14use vortex_array::ArrayView;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::arrays::Extension;
18use vortex_array::arrays::FixedSizeListArray;
19use vortex_array::arrays::PrimitiveArray;
20use vortex_array::arrays::dict::DictArray;
21use vortex_array::arrays::extension::ExtensionArrayExt;
22use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
23use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
24use vortex_array::dtype::Nullability;
25use vortex_array::validity::Validity;
26use vortex_buffer::Buffer;
27use vortex_buffer::BufferMut;
28use vortex_error::VortexExpect;
29use vortex_error::VortexResult;
30use vortex_error::vortex_ensure;
31
32use crate::encodings::turboquant::MAX_BIT_WIDTH;
33use crate::encodings::turboquant::MIN_DIMENSION;
34use crate::encodings::turboquant::centroids::compute_centroid_boundaries;
35use crate::encodings::turboquant::centroids::compute_or_get_centroids;
36use crate::encodings::turboquant::centroids::find_nearest_centroid;
37use crate::scalar_fns::l2_denorm::L2Denorm;
38use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
39use crate::scalar_fns::sorf_transform::SorfMatrix;
40use crate::scalar_fns::sorf_transform::SorfOptions;
41use crate::scalar_fns::sorf_transform::SorfTransform;
42use crate::types::vector::AnyVector;
43use crate::types::vector::Vector;
44use crate::utils::cast_to_f32;
45
/// Tunable parameters for TurboQuant encoding.
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
    /// Bit width handed to the centroid computation. The encoder rejects values outside
    /// `1..=MAX_BIT_WIDTH`.
    pub bit_width: u8,
    /// Seed forwarded to the SORF rotation matrix and recorded in the emitted `SorfOptions`.
    pub seed: u64,
    /// Number of rounds forwarded to the SORF rotation matrix and recorded in the emitted
    /// `SorfOptions`.
    pub num_rounds: u8,
}
56
57impl Default for TurboQuantConfig {
58 fn default() -> Self {
59 Self {
60 bit_width: MAX_BIT_WIDTH,
61 seed: 42,
62 num_rounds: 3,
63 }
64 }
65}
66
67pub fn turboquant_encode(
86 input: ArrayRef,
87 config: &TurboQuantConfig,
88 ctx: &mut ExecutionCtx,
89) -> VortexResult<ArrayRef> {
90 let l2_denorm = normalize_as_l2_denorm(input, ctx)?;
92 let normalized = l2_denorm.child_at(0).clone();
93 let norms = l2_denorm.child_at(1).clone();
94 let num_rows = l2_denorm.len();
95
96 let normalized_ext = normalized
97 .as_opt::<Extension>()
98 .vortex_expect("normalize_as_l2_denorm always produces an Extension array child");
99
100 let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?;
102
103 Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
106}
107
108pub unsafe fn turboquant_encode_unchecked(
122 ext: ArrayView<Extension>,
123 config: &TurboQuantConfig,
124 ctx: &mut ExecutionCtx,
125) -> VortexResult<ArrayRef> {
126 let ext_dtype = ext.dtype().clone();
127 let storage = ext.storage_array();
128 let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
129
130 vortex_ensure!(
131 config.bit_width >= 1 && config.bit_width <= MAX_BIT_WIDTH,
132 "bit_width must be 1-{MAX_BIT_WIDTH}, got {}",
133 config.bit_width
134 );
135 let dimension = fsl.list_size();
136 vortex_ensure!(
137 dimension >= MIN_DIMENSION,
138 "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}",
139 );
140
141 let vector_metadata = ext_dtype.as_extension().metadata::<AnyVector>();
142 let element_ptype = vector_metadata.element_ptype();
143
144 let seed = config.seed;
145 let num_rows = fsl.len();
146
147 if fsl.is_empty() {
148 let padded_dim = dimension.next_power_of_two();
149 let empty_codes = PrimitiveArray::empty::<u8>(Nullability::NonNullable);
150 let empty_centroids = PrimitiveArray::empty::<f32>(Nullability::NonNullable);
151 let empty_dict =
152 DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?;
153 let empty_fsl = FixedSizeListArray::try_new(
154 empty_dict.into_array(),
155 padded_dim,
156 Validity::NonNullable,
157 0,
158 )?;
159 let empty_padded_vector = Vector::try_new_vector_array(empty_fsl.into_array())?;
160
161 let sorf_options = SorfOptions {
162 seed,
163 num_rounds: config.num_rounds,
164 dimensions: dimension,
165 element_ptype,
166 };
167 return Ok(
168 SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(),
169 );
170 }
171
172 let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
173 let quantized_fsl =
174 build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?;
175 let padded_vector = Vector::try_new_vector_array(quantized_fsl)?;
176
177 let sorf_options = SorfOptions {
178 seed,
179 num_rounds: config.num_rounds,
180 dimensions: dimension,
181 element_ptype,
182 };
183 Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array())
184}
185
186struct QuantizationResult {
188 centroids: Buffer<f32>,
189 all_indices: Buffer<u8>,
190 padded_dim: usize,
191}
192
193fn turboquant_quantize_core(
199 fsl: &FixedSizeListArray,
200 seed: u64,
201 bit_width: u8,
202 num_rounds: u8,
203 ctx: &mut ExecutionCtx,
204) -> VortexResult<QuantizationResult> {
205 let dimension = fsl.list_size() as usize;
206 let num_rows = fsl.len();
207
208 let padded_dim = dimension.next_power_of_two();
209 let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?;
210 let padded_dim_u32 =
211 u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
212
213 let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
214 let f32_elements = cast_to_f32(elements_prim)?;
215
216 let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?;
217 let boundaries = compute_centroid_boundaries(¢roids);
218
219 let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
220 let mut padded = vec![0.0f32; padded_dim];
221 let mut rotated = vec![0.0f32; padded_dim];
222
223 let f32_slice = f32_elements.as_slice();
224 for row in 0..num_rows {
225 let x = &f32_slice[row * dimension..(row + 1) * dimension];
226
227 padded[..dimension].copy_from_slice(x);
229 padded[dimension..].fill(0.0);
230
231 rotation.rotate(&padded, &mut rotated);
232
233 for j in 0..padded_dim {
234 all_indices.push(find_nearest_centroid(rotated[j], &boundaries));
235 }
236 }
237
238 Ok(QuantizationResult {
239 centroids,
240 all_indices: all_indices.freeze(),
241 padded_dim,
242 })
243}
244
245fn build_quantized_fsl(
251 num_rows: usize,
252 all_indices: Buffer<u8>,
253 centroids: Buffer<f32>,
254 padded_dim: usize,
255) -> VortexResult<ArrayRef> {
256 let codes = PrimitiveArray::new::<u8>(all_indices, Validity::NonNullable);
257 let centroids_array = PrimitiveArray::new::<f32>(centroids, Validity::NonNullable);
258
259 let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?;
260
261 let padded_dim_u32 =
262 u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
263 Ok(FixedSizeListArray::try_new(
264 dict.into_array(),
265 padded_dim_u32,
266 Validity::NonNullable,
267 num_rows,
268 )?
269 .into_array())
270}