vortex_tensor/encodings/turboquant/array/
data.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::sync::Arc;
7
8use vortex_array::ArrayRef;
9use vortex_array::TypedArrayRef;
10use vortex_array::dtype::DType;
11use vortex_array::dtype::Nullability;
12use vortex_array::dtype::PType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_ensure_eq;
17
18use crate::encodings::turboquant::array::slots::Slot;
19use crate::encodings::turboquant::vtable::TurboQuant;
20
/// Scalar parameters of a TurboQuant array.
///
/// The child arrays (codes, centroids, rotation signs) are not stored here; they live in the
/// array's slots (see `TurboQuantData::make_slots`). Build this value with
/// [`TurboQuantData::try_new`], which validates the parameters, or with the unchecked
/// [`TurboQuantData::new_unchecked`].
#[derive(Clone, Debug)]
pub struct TurboQuantData {
    /// Vector dimension before power-of-two padding (see [`TurboQuantData::padded_dim`]).
    /// `try_new` requires it to be at least [`TurboQuant::MIN_DIMENSION`].
    pub(crate) dimension: u32,

    /// Quantization bit width. `try_new` requires it to be at most
    /// [`TurboQuant::MAX_BIT_WIDTH`]; `validate` derives the same quantity as log2 of the
    /// centroid count.
    pub(crate) bit_width: u8,

    /// Number of rounds.
    /// NOTE(review): presumably equals the number of entries in the `rotation_signs` child —
    /// no visible code checks this, confirm against the vtable.
    pub(crate) num_rounds: u8,
}
51
52impl Display for TurboQuantData {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 write!(
55 f,
56 "dimension: {}, bit_width: {}, num_rounds: {}",
57 self.dimension, self.bit_width, self.num_rounds
58 )
59 }
60}
61
62impl TurboQuantData {
63 pub fn try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> VortexResult<Self> {
71 vortex_ensure!(
72 dimension >= TurboQuant::MIN_DIMENSION,
73 "TurboQuant requires dimension >= {}, got {dimension}",
74 TurboQuant::MIN_DIMENSION
75 );
76 vortex_ensure!(
77 bit_width <= TurboQuant::MAX_BIT_WIDTH,
78 "bit_width is expected to be between 0 and {}, got {bit_width}",
79 TurboQuant::MAX_BIT_WIDTH
80 );
81
82 Ok(Self {
83 dimension,
84 bit_width,
85 num_rounds,
86 })
87 }
88
89 pub unsafe fn new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self {
101 Self {
102 dimension,
103 bit_width,
104 num_rounds,
105 }
106 }
107
108 pub fn validate(
112 dtype: &DType,
113 codes: &ArrayRef,
114 centroids: &ArrayRef,
115 rotation_signs: &ArrayRef,
116 ) -> VortexResult<()> {
117 let vector_metadata = TurboQuant::validate_dtype(dtype)?;
118 let dimension = vector_metadata.dimensions();
119 let padded_dim = dimension.next_power_of_two();
120
121 vortex_ensure!(
124 !dtype.is_nullable(),
125 "TurboQuant dtype must be non-nullable, got {dtype}",
126 );
127
128 let expected_codes_dtype = DType::FixedSizeList(
130 Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
131 padded_dim,
132 Nullability::NonNullable,
133 );
134 vortex_ensure_eq!(
135 *codes.dtype(),
136 expected_codes_dtype,
137 "codes dtype does not match expected {expected_codes_dtype}",
138 );
139
140 let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
142 vortex_ensure_eq!(
143 *centroids.dtype(),
144 centroids_dtype,
145 "centroids dtype must be non-nullable f32",
146 );
147
148 let expected_signs_dtype = DType::FixedSizeList(
151 Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
152 padded_dim,
153 Nullability::NonNullable,
154 );
155 vortex_ensure_eq!(
156 *rotation_signs.dtype(),
157 expected_signs_dtype,
158 "rotation_signs dtype does not match expected {expected_signs_dtype}",
159 );
160 let num_rows = codes.len();
162 if num_rows == 0 {
163 vortex_ensure!(
164 centroids.is_empty(),
165 "degenerate TurboQuant must have empty centroids, got length {}",
166 centroids.len()
167 );
168 vortex_ensure!(
169 rotation_signs.is_empty(),
170 "degenerate TurboQuant must have empty rotation_signs, got length {}",
171 rotation_signs.len()
172 );
173 return Ok(());
174 }
175
176 vortex_ensure!(
177 !rotation_signs.is_empty(),
178 "rotation_signs must have at least 1 round"
179 );
180
181 let num_centroids = centroids.len();
183 vortex_ensure!(
184 num_centroids.is_power_of_two()
185 && (2..=TurboQuant::MAX_CENTROIDS).contains(&num_centroids),
186 "centroids length must be a power of 2 in [2, {}], got {num_centroids}",
187 TurboQuant::MAX_CENTROIDS
188 );
189
190 #[expect(
191 clippy::cast_possible_truncation,
192 reason = "Guaranteed to be [1,8] by the preceding power-of-2 and range checks."
193 )]
194 let bit_width = num_centroids.trailing_zeros() as u8;
195 vortex_ensure!(
196 (1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width),
197 "derived bit_width must be 1-{}, got {bit_width}",
198 TurboQuant::MAX_BIT_WIDTH
199 );
200
201 Ok(())
202 }
203
204 pub(crate) fn make_slots(
205 codes: ArrayRef,
206 centroids: ArrayRef,
207 rotation_signs: ArrayRef,
208 ) -> Vec<Option<ArrayRef>> {
209 let mut slots = vec![None; Slot::COUNT];
210 slots[Slot::Codes as usize] = Some(codes);
211 slots[Slot::Centroids as usize] = Some(centroids);
212 slots[Slot::RotationSigns as usize] = Some(rotation_signs);
213 slots
214 }
215
216 pub fn dimension(&self) -> u32 {
219 self.dimension
220 }
221
222 pub fn bit_width(&self) -> u8 {
224 self.bit_width
225 }
226
227 pub fn num_rounds(&self) -> u8 {
229 self.num_rounds
230 }
231
232 pub fn padded_dim(&self) -> u32 {
237 self.dimension.next_power_of_two()
238 }
239}
240
241pub trait TurboQuantArrayExt: TypedArrayRef<TurboQuant> {
242 fn codes(&self) -> &ArrayRef {
243 self.as_ref().slots()[Slot::Codes as usize]
244 .as_ref()
245 .vortex_expect("TurboQuantArray codes slot")
246 }
247
248 fn centroids(&self) -> &ArrayRef {
249 self.as_ref().slots()[Slot::Centroids as usize]
250 .as_ref()
251 .vortex_expect("TurboQuantArray centroids slot")
252 }
253
254 fn rotation_signs(&self) -> &ArrayRef {
255 self.as_ref().slots()[Slot::RotationSigns as usize]
256 .as_ref()
257 .vortex_expect("TurboQuantArray rotation_signs slot")
258 }
259}
260
261impl<T: TypedArrayRef<TurboQuant>> TurboQuantArrayExt for T {}