vortex_tensor/vector_search.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Reusable helpers for building brute-force vector similarity search expressions over
5//! [`Vector`] extension arrays.
6//!
7//! [`build_similarity_search_tree`] broadcasts the query into the shape expected by
8//! [`CosineSimilarity`] via `Vector::constant_array` and returns a lazy
9//! `Binary(Gt, [CosineSimilarity(data, query), threshold])` expression. The caller is responsible
10//! for preparing `data` (e.g. by compressing it beforehand); this builder does not compress.
11//!
12//! Executing the tree into a [`BoolArray`] yields one boolean per row indicating whether that row's
13//! cosine similarity to the query exceeds `threshold`.
14//!
15//! # Example
16//!
17//! ```ignore
18//! use vortex_array::{ArrayRef, VortexSessionExecute};
19//! use vortex_array::arrays::BoolArray;
20//! use vortex_session::VortexSession;
21//! use vortex_tensor::vector_search::build_similarity_search_tree;
22//!
23//! fn run(session: &VortexSession, data: ArrayRef, query: &[f32]) -> anyhow::Result<()> {
24//! let mut ctx = session.create_execution_ctx();
25//! let tree = build_similarity_search_tree(data, query, 0.8)?;
26//! let _matches: BoolArray = tree.execute(&mut ctx)?;
27//! Ok(())
28//! }
29//! ```
30//!
//! [`Vector`]: crate::types::vector::Vector
32//! [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity
33//! [`BoolArray`]: vortex_array::arrays::BoolArray
34
35use vortex_array::ArrayRef;
36use vortex_array::IntoArray;
37use vortex_array::arrays::ConstantArray;
38use vortex_array::builtins::ArrayBuiltins;
39use vortex_array::dtype::NativePType;
40use vortex_array::dtype::Nullability;
41use vortex_array::scalar::PValue;
42use vortex_array::scalar::Scalar;
43use vortex_array::scalar_fn::fns::operators::Operator;
44use vortex_error::VortexResult;
45
46use crate::scalar_fns::cosine_similarity::CosineSimilarity;
47use crate::types::vector::Vector;
48
49/// Build the lazy similarity-search expression tree for a prepared database array and a
50/// single query vector.
51///
52/// The returned array is a lazy boolean expression of length `data.len()` whose position `i`
53/// is `true` iff `cosine_similarity(data[i], query) > threshold`. Executing it into a
54/// [`BoolArray`](vortex_array::arrays::BoolArray) runs the full scan.
55///
56/// The tree shape is:
57///
58/// ```text
59/// Binary(Gt, [
60/// CosineSimilarity([data, ConstantArray(query_vec, n)]),
61/// ConstantArray(threshold, n),
62/// ])
63/// ```
64///
65/// The element type is inferred from `T` and must match the element type of `data`'s
66/// [`Vector`] extension dtype.
67///
68/// This function performs no execution; it is safe to call inside a benchmark setup closure.
69///
70/// # Errors
71///
72/// Returns an error if `query` has a length incompatible with `data`'s vector dimension, or
73/// if any of the intermediate array constructors fails.
74pub fn build_similarity_search_tree<T: NativePType + Into<PValue>>(
75 data: ArrayRef,
76 query: &[T],
77 threshold: T,
78) -> VortexResult<ArrayRef> {
79 let num_rows = data.len();
80 let query_vec = Vector::constant_array(query, num_rows)?;
81
82 let cosine = CosineSimilarity::try_new_array(data, query_vec)?.into_array();
83
84 let threshold_scalar = Scalar::primitive(threshold, Nullability::NonNullable);
85 let threshold_array = ConstantArray::new(threshold_scalar, num_rows).into_array();
86
87 cosine.binary(threshold_array, Operator::Gt)
88}
89
90#[cfg(test)]
91mod tests {
92 use vortex_array::VortexSessionExecute;
93 use vortex_array::arrays::BoolArray;
94 use vortex_array::arrays::bool::BoolArrayExt;
95 use vortex_error::VortexResult;
96
97 use super::build_similarity_search_tree;
98 use crate::tests::SESSION;
99 use crate::utils::test_helpers::vector_array;
100
101 #[test]
102 fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> {
103 // 4 rows of 3-dim vectors; the first and last match the query [1, 0, 0].
104 let data = vector_array(
105 3,
106 &[
107 1.0f32, 0.0, 0.0, //
108 0.0, 1.0, 0.0, //
109 0.0, 0.0, 1.0, //
110 1.0, 0.0, 0.0, //
111 ],
112 )?;
113 let query = [1.0f32, 0.0, 0.0];
114
115 let tree = build_similarity_search_tree(data, &query, 0.5)?;
116 let mut ctx = SESSION.create_execution_ctx();
117 let result: BoolArray = tree.execute(&mut ctx)?;
118
119 let bits = result.to_bit_buffer();
120 assert_eq!(bits.len(), 4);
121 assert!(bits.value(0));
122 assert!(!bits.value(1));
123 assert!(!bits.value(2));
124 assert!(bits.value(3));
125 Ok(())
126 }
127}