vortex_tensor/encodings/turboquant/scheme/
compress.rs1use vortex_array::ArrayRef;
12use vortex_array::ArrayView;
13use vortex_array::ExecutionCtx;
14use vortex_array::IntoArray;
15use vortex_array::arrays::Extension;
16use vortex_array::arrays::FixedSizeListArray;
17use vortex_array::arrays::PrimitiveArray;
18use vortex_array::arrays::extension::ExtensionArrayExt;
19use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
20use vortex_array::dtype::DType;
21use vortex_array::dtype::Nullability;
22use vortex_array::dtype::PType;
23use vortex_array::validity::Validity;
24use vortex_buffer::BufferMut;
25use vortex_error::VortexExpect;
26use vortex_error::VortexResult;
27use vortex_error::vortex_bail;
28use vortex_error::vortex_ensure;
29use vortex_fastlanes::bitpack_compress::bitpack_encode;
30
31use crate::encodings::turboquant::TurboQuant;
32use crate::encodings::turboquant::array::centroids::compute_centroid_boundaries;
33use crate::encodings::turboquant::array::centroids::find_nearest_centroid;
34use crate::encodings::turboquant::array::centroids::get_centroids;
35use crate::encodings::turboquant::array::rotation::RotationMatrix;
36use crate::encodings::turboquant::vtable::TurboQuantArray;
37use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows;
38
/// Tunable parameters for TurboQuant encoding.
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
    /// Bit width of each quantization code; the encoder rejects values outside
    /// `1..=TurboQuant::MAX_BIT_WIDTH`.
    pub bit_width: u8,
    /// Seed used to build the rotation matrix; `None` falls back to `42` at encode time.
    pub seed: Option<u64>,
    /// Number of rounds handed to the rotation-matrix constructor.
    pub num_rounds: u8,
}
49
50impl Default for TurboQuantConfig {
51 fn default() -> Self {
52 Self {
53 bit_width: TurboQuant::MAX_BIT_WIDTH,
54 seed: Some(42),
55 num_rounds: 3,
56 }
57 }
58}
59
60fn extract_f32_elements(
64 fsl: &FixedSizeListArray,
65 ctx: &mut ExecutionCtx,
66) -> VortexResult<PrimitiveArray> {
67 let elements = fsl.elements();
68 let primitive = elements.clone().execute::<PrimitiveArray>(ctx)?;
69 let ptype = primitive.ptype();
70
71 match ptype {
72 PType::F16 => Ok(primitive
73 .as_slice::<half::f16>()
74 .iter()
75 .map(|&v| f32::from(v))
76 .collect()),
77 PType::F32 => Ok(primitive),
78 PType::F64 => Ok(primitive
79 .as_slice::<f64>()
80 .iter()
81 .map(|&v| {
82 #[expect(
83 clippy::cast_possible_truncation,
84 reason = "TurboQuant quantization operates in f32, so f64 inputs are intentionally downcast"
85 )]
86 let v = v as f32;
87 v
88 })
89 .collect()),
90 _ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"),
91 }
92}
93
94struct QuantizationResult {
96 rotation: RotationMatrix,
97 centroids: Vec<f32>,
98 all_indices: BufferMut<u8>,
99 padded_dim: usize,
100}
101
102fn turboquant_quantize_core(
108 fsl: &FixedSizeListArray,
109 seed: u64,
110 bit_width: u8,
111 num_rounds: u8,
112 ctx: &mut ExecutionCtx,
113) -> VortexResult<QuantizationResult> {
114 let dimension =
115 usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize");
116 let num_rows = fsl.len();
117
118 let rotation = RotationMatrix::try_new(seed, dimension, num_rounds as usize)?;
119 let padded_dim = rotation.padded_dim();
120 let padded_dim_u32 =
121 u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
122
123 let f32_elements = extract_f32_elements(fsl, ctx)?;
124
125 let centroids = get_centroids(padded_dim_u32, bit_width)?;
126 let boundaries = compute_centroid_boundaries(¢roids);
127
128 let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
129 let mut padded = vec![0.0f32; padded_dim];
130 let mut rotated = vec![0.0f32; padded_dim];
131
132 let f32_slice = f32_elements.as_slice::<f32>();
133 for row in 0..num_rows {
134 let x = &f32_slice[row * dimension..(row + 1) * dimension];
135
136 padded[..dimension].copy_from_slice(x);
138 padded[dimension..].fill(0.0);
139
140 rotation.rotate(&padded, &mut rotated);
141
142 for j in 0..padded_dim {
143 all_indices.push(find_nearest_centroid(rotated[j], &boundaries));
144 }
145 }
146
147 Ok(QuantizationResult {
148 rotation,
149 centroids,
150 all_indices,
151 padded_dim,
152 })
153}
154
155fn build_turboquant(
159 num_rows: usize,
160 core: QuantizationResult,
161 ext_dtype: &DType,
162) -> VortexResult<TurboQuantArray> {
163 let padded_dim = core.padded_dim;
164 let padded_dim_u32 =
165 u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32");
166 let codes_elements =
167 PrimitiveArray::new::<u8>(core.all_indices.freeze(), Validity::NonNullable);
168 let codes = FixedSizeListArray::try_new(
169 codes_elements.into_array(),
170 padded_dim_u32,
171 Validity::NonNullable,
172 num_rows,
173 )?
174 .into_array();
175
176 let mut centroids_buf = BufferMut::<f32>::with_capacity(core.centroids.len());
179 centroids_buf.extend_from_slice(&core.centroids);
180 let centroids_array =
181 PrimitiveArray::new::<f32>(centroids_buf.freeze(), Validity::NonNullable).into_array();
182
183 let rotation_signs = bitpack_rotation_signs(&core.rotation)?;
184
185 TurboQuant::try_new_array(ext_dtype.clone(), codes, centroids_array, rotation_signs)
186}
187
188pub fn turboquant_encode(
206 ext: ArrayView<Extension>,
207 config: &TurboQuantConfig,
208 ctx: &mut ExecutionCtx,
209) -> VortexResult<ArrayRef> {
210 let ext_dtype = ext.dtype().clone();
211
212 vortex_ensure!(
213 !ext_dtype.is_nullable(),
214 "TurboQuant input must be non-nullable (normalize first via L2Denorm), got {ext_dtype}",
215 );
216
217 validate_l2_normalized_rows(ext.as_ref().clone(), ctx)?;
218
219 unsafe { turboquant_encode_unchecked(ext, config, ctx) }
221}
222
223pub unsafe fn turboquant_encode_unchecked(
236 ext: ArrayView<Extension>,
237 config: &TurboQuantConfig,
238 ctx: &mut ExecutionCtx,
239) -> VortexResult<ArrayRef> {
240 let ext_dtype = ext.dtype().clone();
241 let storage = ext.storage_array();
242 let fsl = storage.clone().execute::<FixedSizeListArray>(ctx)?;
243
244 vortex_ensure!(
245 config.bit_width >= 1 && config.bit_width <= TurboQuant::MAX_BIT_WIDTH,
246 "bit_width must be 1-{}, got {}",
247 TurboQuant::MAX_BIT_WIDTH,
248 config.bit_width
249 );
250 let dimension = fsl.list_size();
251 vortex_ensure!(
252 dimension >= TurboQuant::MIN_DIMENSION,
253 "TurboQuant requires dimension >= {}, got {dimension}",
254 TurboQuant::MIN_DIMENSION
255 );
256
257 if fsl.is_empty() {
258 let padded_dim = dimension.next_power_of_two();
259 let empty_codes = FixedSizeListArray::try_new(
260 PrimitiveArray::empty::<u8>(Nullability::NonNullable).into_array(),
261 padded_dim,
262 Validity::NonNullable,
263 0,
264 )?;
265
266 let empty_centroids = PrimitiveArray::empty::<f32>(Nullability::NonNullable);
267 let empty_signs = FixedSizeListArray::try_new(
268 PrimitiveArray::empty::<u8>(Nullability::NonNullable).into_array(),
269 padded_dim,
270 Validity::NonNullable,
271 0,
272 )?;
273
274 return Ok(TurboQuant::try_new_array(
275 ext_dtype,
276 empty_codes.into_array(),
277 empty_centroids.into_array(),
278 empty_signs.into_array(),
279 )?
280 .into_array());
281 }
282
283 let seed = config.seed.unwrap_or(42);
284 let num_rows = fsl.len();
285 let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
286
287 Ok(build_turboquant(num_rows, core, &ext_dtype)?.into_array())
288}
289
290fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult<ArrayRef> {
296 let signs_u8 = rotation.export_inverse_signs_u8();
297 let num_rounds = rotation.num_rounds();
298 let padded_dim = u32::try_from(rotation.padded_dim()).vortex_expect("padded_dim fits in u32");
299
300 let mut buf = BufferMut::<u8>::with_capacity(signs_u8.len());
301 buf.extend_from_slice(&signs_u8);
302 let prim = PrimitiveArray::new::<u8>(buf.freeze(), Validity::NonNullable);
303 let bitpacked = bitpack_encode(&prim, 1, None)?;
304
305 let fsl = FixedSizeListArray::try_new(
306 bitpacked.into_array(),
307 padded_dim,
308 Validity::NonNullable,
309 num_rounds,
310 )?;
311 Ok(fsl.into_array())
312}