vortex_tensor/vector_search.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Reusable helpers for building brute-force vector similarity search expressions over
5//! [`Vector`] extension arrays.
6//!
7//! [`build_similarity_search_tree`] broadcasts the query into the shape expected by
8//! [`CosineSimilarity`] via `Vector::constant_array` and returns a lazy
9//! `Binary(Gt, [CosineSimilarity(data, query), threshold])` expression. The caller is responsible
10//! for preparing `data` (e.g. by running it through [`turboquant_encode`]); this builder does not
11//! compress.
12//!
13//! Executing the tree into a [`BoolArray`] yields one boolean per row indicating whether that row's
14//! cosine similarity to the query exceeds `threshold`.
15//!
16//! # Example
17//!
18//! ```ignore
19//! use vortex_array::{ArrayRef, VortexSessionExecute};
20//! use vortex_array::arrays::BoolArray;
21//! use vortex_session::VortexSession;
22//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode};
23//! use vortex_tensor::vector_search::build_similarity_search_tree;
24//!
25//! fn run(session: &VortexSession, data: ArrayRef, query: &[f32]) -> anyhow::Result<()> {
26//! let mut ctx = session.create_execution_ctx();
27//! let data = turboquant_encode(data, &TurboQuantConfig::default(), &mut ctx)?;
28//! let tree = build_similarity_search_tree(data, query, 0.8)?;
29//! let _matches: BoolArray = tree.execute(&mut ctx)?;
30//! Ok(())
31//! }
32//! ```
33//!
//! [`Vector`]: crate::types::vector::Vector
35//! [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity
36//! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode
37//! [`BoolArray`]: vortex_array::arrays::BoolArray
38
39use vortex_array::ArrayRef;
40use vortex_array::IntoArray;
41use vortex_array::arrays::ConstantArray;
42use vortex_array::builtins::ArrayBuiltins;
43use vortex_array::dtype::NativePType;
44use vortex_array::dtype::Nullability;
45use vortex_array::scalar::PValue;
46use vortex_array::scalar::Scalar;
47use vortex_array::scalar_fn::fns::operators::Operator;
48use vortex_error::VortexResult;
49
50use crate::scalar_fns::cosine_similarity::CosineSimilarity;
51use crate::types::vector::Vector;
52
53/// Build the lazy similarity-search expression tree for a prepared database array and a
54/// single query vector.
55///
56/// The returned array is a lazy boolean expression of length `data.len()` whose position `i`
57/// is `true` iff `cosine_similarity(data[i], query) > threshold`. Executing it into a
58/// [`BoolArray`](vortex_array::arrays::BoolArray) runs the full scan.
59///
60/// The tree shape is:
61///
62/// ```text
63/// Binary(Gt, [
64/// CosineSimilarity([data, ConstantArray(query_vec, n)]),
65/// ConstantArray(threshold, n),
66/// ])
67/// ```
68///
69/// The element type is inferred from `T` and must match the element type of `data`'s
70/// [`Vector`] extension dtype.
71///
72/// This function performs no execution; it is safe to call inside a benchmark setup closure.
73///
74/// # Errors
75///
76/// Returns an error if `query` has a length incompatible with `data`'s vector dimension, or
77/// if any of the intermediate array constructors fails.
78pub fn build_similarity_search_tree<T: NativePType + Into<PValue>>(
79 data: ArrayRef,
80 query: &[T],
81 threshold: T,
82) -> VortexResult<ArrayRef> {
83 let num_rows = data.len();
84 let query_vec = Vector::constant_array(query, num_rows)?;
85
86 let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array();
87
88 let threshold_scalar = Scalar::primitive(threshold, Nullability::NonNullable);
89 let threshold_array = ConstantArray::new(threshold_scalar, num_rows).into_array();
90
91 cosine.binary(threshold_array, Operator::Gt)
92}
93
#[cfg(test)]
mod tests {
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::bool::BoolArrayExt;
    use vortex_error::VortexResult;

    use super::build_similarity_search_tree;
    use crate::encodings::turboquant::TurboQuantConfig;
    use crate::encodings::turboquant::turboquant_encode;
    use crate::tests::SESSION;
    use crate::utils::test_helpers::vector_array;

    /// Uncompressed data: only the rows equal to the query should pass a 0.5 threshold.
    #[test]
    fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> {
        // Four 3-dim rows. Rows 0 and 3 are identical to the query [1, 0, 0]; rows 1 and 2
        // point along the other two axes.
        let rows = [
            1.0f32, 0.0, 0.0, //
            0.0, 1.0, 0.0, //
            0.0, 0.0, 1.0, //
            1.0, 0.0, 0.0, //
        ];
        let data = vector_array(3, &rows)?;
        let query = [1.0f32, 0.0, 0.0];

        let tree = build_similarity_search_tree(data, &query, 0.5)?;
        let mut ctx = SESSION.create_execution_ctx();
        let hits: BoolArray = tree.execute(&mut ctx)?;

        let mask = hits.to_bit_buffer();
        assert_eq!(mask.len(), 4);
        let observed: Vec<bool> = (0..mask.len()).map(|row| mask.value(row)).collect();
        assert_eq!(observed, [true, false, false, true]);
        Ok(())
    }

    /// TurboQuant-compressed data: the row identical to the query must still match.
    #[test]
    fn turboquant_roundtrip_preserves_ranking() -> VortexResult<()> {
        // Six 128-dim rows, with row 0 being an exact copy of the query.
        const DIM: u32 = 128;
        const NUM_ROWS: usize = 6;
        let dim = DIM as usize;

        let query: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.017).sin()).collect();

        let mut values = Vec::<f32>::with_capacity(NUM_ROWS * dim);
        // Row 0: identical to the query (cosine = 1.0 before quantization).
        values.extend_from_slice(&query);
        // Row 1: the query plus a small cosine-shaped perturbation.
        values.extend(
            query
                .iter()
                .enumerate()
                .map(|(i, q)| q + 0.05 * ((i as f32) * 0.03).cos()),
        );
        // Rows 2..NUM_ROWS: sine patterns unrelated to the query.
        values.extend((2..NUM_ROWS).flat_map(|row| {
            (0..dim).map(move |i| ((row as f32 * 1.3 + i as f32) * 0.07).sin())
        }));

        let data = vector_array(DIM, &values)?;
        let mut ctx = SESSION.create_execution_ctx();
        let compressed = turboquant_encode(data, &TurboQuantConfig::default(), &mut ctx)?;
        assert_eq!(compressed.len(), NUM_ROWS);

        // Row 0 has cosine 1.0 against the query before quantization, so it is expected to
        // stay above the 0.95 threshold after the TurboQuant round trip.
        let tree = build_similarity_search_tree(compressed, &query, 0.95)?;
        let hits: BoolArray = tree.execute(&mut ctx)?;
        let mask = hits.to_bit_buffer();
        assert_eq!(mask.len(), NUM_ROWS);
        assert!(
            mask.value(0),
            "row 0 (identical to query) must match at threshold 0.95 even after TurboQuant"
        );
        Ok(())
    }
}