Skip to main content

vortex_tensor/
vector_search.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Reusable helpers for building brute-force vector similarity search expressions over
5//! [`Vector`] extension arrays.
6//!
7//! This module exposes three small building blocks that together make it straightforward to
8//! stand up a cosine-similarity-plus-threshold scan on top of a prepared data array:
9//!
10//! - [`compress_turboquant`] applies the canonical TurboQuant encoding pipeline
11//!   (`L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)`) to a raw
12//!   `Vector<dim, f32>` array without requiring the caller to plumb the
13//!   `unstable_encodings` feature flag on the `vortex` facade.
14//! - [`build_constant_query_vector`] wraps a single query vector into a
15//!   [`Vector`] extension array whose storage is a [`ConstantArray`] broadcast
16//!   across `num_rows` rows. This is the shape expected by
17//!   [`CosineSimilarity::try_new_array`] for the RHS of a database-vs-query scan.
18//! - [`build_similarity_search_tree`] wires everything together into a lazy
19//!   `Binary(Gt, [CosineSimilarity(data, query), threshold])` expression.
20//!
21//! Executing the tree from [`build_similarity_search_tree`] into a
22//! [`BoolArray`](vortex_array::arrays::BoolArray) yields one boolean per row indicating whether
23//! that row's cosine similarity to the query exceeds `threshold`.
24//!
25//! # Example
26//!
27//! ```ignore
28//! use vortex_array::{ArrayRef, VortexSessionExecute};
29//! use vortex_array::arrays::BoolArray;
30//! use vortex_session::VortexSession;
31//! use vortex_tensor::vector_search::{build_similarity_search_tree, compress_turboquant};
32//!
33//! fn run(session: &VortexSession, data: ArrayRef, query: &[f32]) -> anyhow::Result<()> {
34//!     let mut ctx = session.create_execution_ctx();
35//!     let data = compress_turboquant(data, &mut ctx)?;
36//!     let tree = build_similarity_search_tree(data, query, 0.8)?;
37//!     let _matches: BoolArray = tree.execute(&mut ctx)?;
38//!     Ok(())
39//! }
40//! ```
41//!
42//! [`Vector`]: crate::vector::Vector
43//! [`CosineSimilarity::try_new_array`]: crate::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array
44
45use vortex_array::ArrayRef;
46use vortex_array::ExecutionCtx;
47use vortex_array::IntoArray;
48use vortex_array::arrays::ConstantArray;
49use vortex_array::arrays::Extension;
50use vortex_array::arrays::ExtensionArray;
51use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
52use vortex_array::builtins::ArrayBuiltins;
53use vortex_array::dtype::DType;
54use vortex_array::dtype::NativePType;
55use vortex_array::dtype::Nullability;
56use vortex_array::dtype::extension::ExtDType;
57use vortex_array::extension::EmptyMetadata;
58use vortex_array::scalar::PValue;
59use vortex_array::scalar::Scalar;
60use vortex_array::scalar_fn::fns::operators::Operator;
61use vortex_error::VortexResult;
62use vortex_error::vortex_bail;
63
64use crate::encodings::turboquant::TurboQuantConfig;
65use crate::encodings::turboquant::turboquant_encode_unchecked;
66use crate::scalar_fns::cosine_similarity::CosineSimilarity;
67use crate::scalar_fns::l2_denorm::L2Denorm;
68use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
69use crate::vector::Vector;
70
71/// Apply the canonical TurboQuant encoding pipeline to a `Vector<dim, f32>` array.
72///
73/// The returned array has the shape
74/// `L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)` — exactly what
75/// [`crate::encodings::turboquant::TurboQuantScheme`] produces when invoked through
76/// `BtrBlocksCompressorBuilder::with_turboquant()`, but without requiring callers to enable
77/// the `unstable_encodings` feature on the `vortex` facade.
78///
79/// The input `data` must be a [`Vector`] extension array whose element type is `f32` and whose
80/// dimensionality is at least
81/// [`turboquant::MIN_DIMENSION`](crate::encodings::turboquant::MIN_DIMENSION). The TurboQuant
82/// configuration used is [`TurboQuantConfig::default()`] (8-bit codes, 3 SORF rounds, seed 42).
83///
84/// # Errors
85///
86/// Returns an error if `data` is not a [`Vector`] extension array, if normalization fails, or
87/// if the underlying TurboQuant encoder rejects the input shape.
88pub fn compress_turboquant(data: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
89    let l2_denorm = normalize_as_l2_denorm(data, ctx)?;
90    let normalized = l2_denorm.child_at(0).clone();
91    let norms = l2_denorm.child_at(1).clone();
92    let num_rows = l2_denorm.len();
93
94    let Some(normalized_ext) = normalized.as_opt::<Extension>() else {
95        vortex_bail!("normalize_as_l2_denorm must produce an Extension array child");
96    };
97
98    let config = TurboQuantConfig::default();
99    // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero), which is
100    // the invariant `turboquant_encode_unchecked` expects.
101    let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, ctx) }?;
102
103    Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
104}
105
106/// Build a [`Vector`] extension array whose storage is a [`ConstantArray`] broadcasting a single
107/// query vector across `num_rows` rows.
108///
109/// The element type is inferred from `T` (e.g. `f32` or `f64`). This is the shape expected for
110/// the RHS of a database-vs-query [`CosineSimilarity`] scan: the `ScalarFnArray` contract
111/// requires both children to have the same length, so rather than hand-rolling a 1-row input we
112/// broadcast the query across the whole database.
113///
114/// # Errors
115///
116/// Returns an error if the [`Vector`] extension dtype rejects the constructed storage dtype.
117pub fn build_constant_query_vector<T: NativePType + Into<PValue>>(
118    query: &[T],
119    num_rows: usize,
120) -> VortexResult<ArrayRef> {
121    let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable);
122
123    let children: Vec<Scalar> = query
124        .iter()
125        .map(|&v| Scalar::primitive(v, Nullability::NonNullable))
126        .collect();
127    let storage_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
128
129    let storage = ConstantArray::new(storage_scalar, num_rows).into_array();
130
131    let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
132    Ok(ExtensionArray::new(ext_dtype, storage).into_array())
133}
134
135/// Build the lazy similarity-search expression tree for a prepared database array and a
136/// single query vector.
137///
138/// The returned array is a lazy boolean expression of length `data.len()` whose position `i`
139/// is `true` iff `cosine_similarity(data[i], query) > threshold`. Executing it into a
140/// [`BoolArray`](vortex_array::arrays::BoolArray) runs the full scan.
141///
142/// The tree shape is:
143///
144/// ```text
145/// Binary(Gt, [
146///     CosineSimilarity([data, ConstantArray(query_vec, n)]),
147///     ConstantArray(threshold, n),
148/// ])
149/// ```
150///
151/// The element type is inferred from `T` and must match the element type of `data`'s
152/// [`Vector`] extension dtype.
153///
154/// This function performs no execution; it is safe to call inside a benchmark setup closure.
155///
156/// # Errors
157///
158/// Returns an error if `query` has a length incompatible with `data`'s vector dimension, or
159/// if any of the intermediate array constructors fails.
160pub fn build_similarity_search_tree<T: NativePType + Into<PValue>>(
161    data: ArrayRef,
162    query: &[T],
163    threshold: T,
164) -> VortexResult<ArrayRef> {
165    let num_rows = data.len();
166    let query_vec = build_constant_query_vector(query, num_rows)?;
167
168    let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array();
169
170    let threshold_scalar = Scalar::primitive(threshold, Nullability::NonNullable);
171    let threshold_array = ConstantArray::new(threshold_scalar, num_rows).into_array();
172
173    cosine.binary(threshold_array, Operator::Gt)
174}
175
#[cfg(test)]
mod tests {
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::Extension;
    use vortex_array::arrays::ExtensionArray;
    use vortex_array::arrays::FixedSizeListArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::bool::BoolArrayExt;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_array::session::ArraySession;
    use vortex_array::validity::Validity;
    use vortex_buffer::BufferMut;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use super::build_constant_query_vector;
    use super::build_similarity_search_tree;
    use super::compress_turboquant;
    use crate::vector::Vector;

    /// Build a `Vector<DIM, f32>` extension array from a flat f32 slice. Each contiguous
    /// group of `DIM` values becomes one row.
    fn vector_array(dim: u32, values: &[f32]) -> VortexResult<ArrayRef> {
        let dim_usize = dim as usize;
        assert_eq!(values.len() % dim_usize, 0);
        let num_rows = values.len() / dim_usize;

        let mut buf = BufferMut::<f32>::with_capacity(values.len());
        for &v in values {
            buf.push(v);
        }
        let elements = PrimitiveArray::new::<f32>(buf.freeze(), Validity::NonNullable);
        let fsl = FixedSizeListArray::try_new(
            elements.into_array(),
            dim,
            Validity::NonNullable,
            num_rows,
        )?;

        let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
        Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
    }

    fn test_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    #[test]
    fn constant_query_vector_has_vector_extension_dtype() -> VortexResult<()> {
        let query = vec![1.0f32, 0.0, 0.0, 0.0];
        let rhs = build_constant_query_vector(&query, 5)?;

        assert_eq!(rhs.len(), 5);
        assert!(rhs.as_opt::<Extension>().is_some());
        Ok(())
    }

    #[test]
    fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> {
        // 4 rows of 3-dim vectors; the first and last match the query [1, 0, 0].
        let data = vector_array(
            3,
            &[
                1.0, 0.0, 0.0, //
                0.0, 1.0, 0.0, //
                0.0, 0.0, 1.0, //
                1.0, 0.0, 0.0, //
            ],
        )?;
        let query = [1.0f32, 0.0, 0.0];

        let tree = build_similarity_search_tree(data, &query, 0.5)?;
        let mut ctx = test_session().create_execution_ctx();
        let result: BoolArray = tree.execute(&mut ctx)?;

        let bits = result.to_bit_buffer();
        assert_eq!(bits.len(), 4);
        assert!(bits.value(0));
        assert!(!bits.value(1));
        assert!(!bits.value(2));
        assert!(bits.value(3));
        Ok(())
    }

    #[test]
    fn turboquant_roundtrip_preserves_ranking() -> VortexResult<()> {
        // Build 6 rows of 128-dim vectors with two clearly separated groups:
        // - row 0 is the query itself (cosine 1.0 before compression);
        // - row 1 is the query plus a small perturbation (cosine just below 1.0);
        // - rows 2..6 are unrelated sinusoids with low cosine to the query.
        // TurboQuant is lossy, but it should keep the near-duplicates ranked above the
        // unrelated rows.
        const DIM: u32 = 128;
        const NUM_ROWS: usize = 6;
        const THRESHOLD: f32 = 0.95;

        let mut values = Vec::<f32>::with_capacity(NUM_ROWS * DIM as usize);
        let query: Vec<f32> = (0..DIM as usize)
            .map(|i| ((i as f32) * 0.017).sin())
            .collect();

        // Row 0: identical to query (cosine=1.0)
        values.extend_from_slice(&query);
        // Row 1: query + noise
        for (i, q) in query.iter().enumerate() {
            values.push(q + 0.05 * ((i as f32) * 0.03).cos());
        }
        // Rows 2..6: unrelated patterns
        for row in 2..NUM_ROWS {
            for i in 0..DIM as usize {
                values.push(((row as f32 * 1.3 + i as f32) * 0.07).sin());
            }
        }

        let data = vector_array(DIM, &values)?;
        let mut ctx = test_session().create_execution_ctx();
        let compressed = compress_turboquant(data, &mut ctx)?;
        assert_eq!(compressed.len(), NUM_ROWS);

        // The threshold is far from both groups, so small quantization error cannot flip any
        // row: rows 0 and 1 sit near 1.0, and rows 2..6 sit far below the threshold.
        let tree = build_similarity_search_tree(compressed, &query, THRESHOLD)?;
        let result: BoolArray = tree.execute(&mut ctx)?;
        let bits = result.to_bit_buffer();
        assert_eq!(bits.len(), NUM_ROWS);
        assert!(
            bits.value(0),
            "row 0 (identical to query) must match at threshold 0.95 even after TurboQuant"
        );
        assert!(
            bits.value(1),
            "row 1 (query plus small noise) must match at threshold 0.95 even after TurboQuant"
        );
        for row in 2..NUM_ROWS {
            assert!(
                !bits.value(row),
                "row {row} (unrelated pattern) must not match at threshold 0.95 after TurboQuant"
            );
        }
        Ok(())
    }
}