vortex_tensor/
vector_search.rs1use vortex_array::ArrayRef;
46use vortex_array::ExecutionCtx;
47use vortex_array::IntoArray;
48use vortex_array::arrays::ConstantArray;
49use vortex_array::arrays::Extension;
50use vortex_array::arrays::ExtensionArray;
51use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
52use vortex_array::builtins::ArrayBuiltins;
53use vortex_array::dtype::DType;
54use vortex_array::dtype::NativePType;
55use vortex_array::dtype::Nullability;
56use vortex_array::dtype::extension::ExtDType;
57use vortex_array::extension::EmptyMetadata;
58use vortex_array::scalar::PValue;
59use vortex_array::scalar::Scalar;
60use vortex_array::scalar_fn::fns::operators::Operator;
61use vortex_error::VortexResult;
62use vortex_error::vortex_bail;
63
64use crate::encodings::turboquant::TurboQuantConfig;
65use crate::encodings::turboquant::turboquant_encode_unchecked;
66use crate::scalar_fns::cosine_similarity::CosineSimilarity;
67use crate::scalar_fns::l2_denorm::L2Denorm;
68use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
69use crate::vector::Vector;
70
71pub fn compress_turboquant(data: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
89 let l2_denorm = normalize_as_l2_denorm(data, ctx)?;
90 let normalized = l2_denorm.child_at(0).clone();
91 let norms = l2_denorm.child_at(1).clone();
92 let num_rows = l2_denorm.len();
93
94 let Some(normalized_ext) = normalized.as_opt::<Extension>() else {
95 vortex_bail!("normalize_as_l2_denorm must produce an Extension array child");
96 };
97
98 let config = TurboQuantConfig::default();
99 let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, ctx) }?;
102
103 Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
104}
105
106pub fn build_constant_query_vector<T: NativePType + Into<PValue>>(
118 query: &[T],
119 num_rows: usize,
120) -> VortexResult<ArrayRef> {
121 let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable);
122
123 let children: Vec<Scalar> = query
124 .iter()
125 .map(|&v| Scalar::primitive(v, Nullability::NonNullable))
126 .collect();
127 let storage_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
128
129 let storage = ConstantArray::new(storage_scalar, num_rows).into_array();
130
131 let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
132 Ok(ExtensionArray::new(ext_dtype, storage).into_array())
133}
134
135pub fn build_similarity_search_tree<T: NativePType + Into<PValue>>(
161 data: ArrayRef,
162 query: &[T],
163 threshold: T,
164) -> VortexResult<ArrayRef> {
165 let num_rows = data.len();
166 let query_vec = build_constant_query_vector(query, num_rows)?;
167
168 let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array();
169
170 let threshold_scalar = Scalar::primitive(threshold, Nullability::NonNullable);
171 let threshold_array = ConstantArray::new(threshold_scalar, num_rows).into_array();
172
173 cosine.binary(threshold_array, Operator::Gt)
174}
175
#[cfg(test)]
mod tests {
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::Extension;
    use vortex_array::arrays::ExtensionArray;
    use vortex_array::arrays::FixedSizeListArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::bool::BoolArrayExt;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_array::session::ArraySession;
    use vortex_array::validity::Validity;
    use vortex_buffer::BufferMut;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use super::build_constant_query_vector;
    use super::build_similarity_search_tree;
    use super::compress_turboquant;
    use crate::vector::Vector;

    /// Wraps a flat, row-major slice of `f32`s into a non-nullable `Vector` extension array
    /// whose rows each hold `dim` elements.
    ///
    /// Panics if `values.len()` is not a multiple of `dim`.
    fn vector_array(dim: u32, values: &[f32]) -> VortexResult<ArrayRef> {
        let width = dim as usize;
        assert_eq!(values.len() % width, 0);
        let rows = values.len() / width;

        let mut buffer = BufferMut::<f32>::with_capacity(values.len());
        values.iter().for_each(|&x| buffer.push(x));
        let elements = PrimitiveArray::new::<f32>(buffer.freeze(), Validity::NonNullable);

        let storage =
            FixedSizeListArray::try_new(elements.into_array(), dim, Validity::NonNullable, rows)?;
        let dtype = ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
        Ok(ExtensionArray::new(dtype, storage.into_array()).into_array())
    }

    /// A minimal session that only registers the array session.
    fn test_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    #[test]
    fn constant_query_vector_has_vector_extension_dtype() -> VortexResult<()> {
        let query = [1.0f32, 0.0, 0.0, 0.0];
        let broadcast = build_constant_query_vector(&query, 5)?;

        assert_eq!(broadcast.len(), 5);
        assert!(broadcast.as_opt::<Extension>().is_some());
        Ok(())
    }

    #[test]
    fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> {
        // Four 3-d rows: the x axis, the y axis, the z axis, and the x axis again.
        #[rustfmt::skip]
        let rows: [f32; 12] = [
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0,
            1.0, 0.0, 0.0,
        ];
        let data = vector_array(3, &rows)?;
        let query = [1.0f32, 0.0, 0.0];

        let tree = build_similarity_search_tree(data, &query, 0.5)?;
        let mut ctx = test_session().create_execution_ctx();
        let matches: BoolArray = tree.execute(&mut ctx)?;

        let bits = matches.to_bit_buffer();
        assert_eq!(bits.len(), 4);
        // Only the rows aligned with the query clear the 0.5 threshold.
        let expected = [true, false, false, true];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(bits.value(row), want, "row {row}");
        }
        Ok(())
    }

    #[test]
    fn turboquant_roundtrip_preserves_ranking() -> VortexResult<()> {
        const DIM: u32 = 128;
        const NUM_ROWS: usize = 6;
        let dim = DIM as usize;

        // Row 0: the query itself.
        let query: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.017).sin()).collect();
        // Row 1: a small perturbation of the query.
        let near_duplicate = query
            .iter()
            .enumerate()
            .map(|(i, q)| q + 0.05 * ((i as f32) * 0.03).cos());
        // Rows 2..NUM_ROWS: unrelated sinusoids.
        let unrelated = (2..NUM_ROWS).flat_map(|row| {
            (0..dim).map(move |i| ((row as f32 * 1.3 + i as f32) * 0.07).sin())
        });

        let mut values = Vec::<f32>::with_capacity(NUM_ROWS * dim);
        values.extend_from_slice(&query);
        values.extend(near_duplicate);
        values.extend(unrelated);

        let data = vector_array(DIM, &values)?;
        let mut ctx = test_session().create_execution_ctx();
        let compressed = compress_turboquant(data, &mut ctx)?;
        assert_eq!(compressed.len(), NUM_ROWS);

        let tree = build_similarity_search_tree(compressed, &query, 0.95)?;
        let matches: BoolArray = tree.execute(&mut ctx)?;
        let bits = matches.to_bit_buffer();
        assert_eq!(bits.len(), NUM_ROWS);
        assert!(
            bits.value(0),
            "row 0 (identical to query) must match at threshold 0.95 even after TurboQuant"
        );
        Ok(())
    }
}