// swiftide_integrations/lancedb/retrieve.rs
use anyhow::Result;
use arrow_array::{Array, RecordBatch, StringArray};
use async_trait::async_trait;
use futures_util::TryStreamExt;
use itertools::Itertools;
use lancedb::query::{ExecutableQuery, QueryBase};
use swiftide_core::{
    Retrieve,
    document::Document,
    indexing::Metadata,
    querying::{
        Query,
        search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
        states,
    },
};

use super::{FieldConfig, LanceDB};
19
20#[async_trait]
26impl Retrieve<SimilaritySingleEmbedding<String>> for LanceDB {
27 #[tracing::instrument]
28 async fn retrieve(
29 &self,
30 search_strategy: &SimilaritySingleEmbedding<String>,
31 query: Query<states::Pending>,
32 ) -> Result<Query<states::Retrieved>> {
33 let Some(embedding) = &query.embedding else {
34 anyhow::bail!("No embedding for query")
35 };
36
37 let table = self
38 .get_connection()
39 .await?
40 .open_table(&self.table_name)
41 .execute()
42 .await?;
43
44 let vector_fields = self
45 .fields
46 .iter()
47 .filter(|field| matches!(field, FieldConfig::Vector(_)))
48 .collect_vec();
49
50 if vector_fields.is_empty() || vector_fields.len() > 1 {
51 anyhow::bail!("Zero or multiple vector fields configured in schema")
52 }
53
54 let column_name = vector_fields.first().map(|v| v.field_name()).unwrap();
55
56 let mut query_builder = table
57 .query()
58 .nearest_to(embedding.as_slice())?
59 .column(&column_name)
60 .limit(usize::try_from(search_strategy.top_k())?);
61
62 if let Some(filter) = &search_strategy.filter() {
63 query_builder = query_builder.only_if(filter);
64 }
65
66 let batches = query_builder
67 .execute()
68 .await?
69 .try_collect::<Vec<_>>()
70 .await?;
71
72 let documents = Self::retrieve_from_record_batches(batches.as_slice());
73
74 Ok(query.retrieved_documents(documents))
75 }
76}
77
78#[async_trait]
79impl Retrieve<SimilaritySingleEmbedding> for LanceDB {
80 async fn retrieve(
81 &self,
82 search_strategy: &SimilaritySingleEmbedding,
83 query: Query<states::Pending>,
84 ) -> Result<Query<states::Retrieved>> {
85 Retrieve::<SimilaritySingleEmbedding<String>>::retrieve(
86 self,
87 &search_strategy.into_concrete_filter::<String>(),
88 query,
89 )
90 .await
91 }
92}
93
94#[async_trait]
95impl<Q: ExecutableQuery + Send + Sync + 'static> Retrieve<CustomStrategy<Q>> for LanceDB {
96 async fn retrieve(
101 &self,
102 search_strategy: &CustomStrategy<Q>,
103 query: Query<states::Pending>,
104 ) -> Result<Query<states::Retrieved>> {
105 let query_builder = search_strategy.build_query(&query).await?;
107
108 let batches = query_builder
110 .execute()
111 .await?
112 .try_collect::<Vec<_>>()
113 .await?;
114
115 let documents = Self::retrieve_from_record_batches(batches.as_slice());
116
117 Ok(query.retrieved_documents(documents))
118 }
119}
120
121impl LanceDB {
122 fn retrieve_from_record_batches(batches: &[RecordBatch]) -> Vec<Document> {
128 let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
129 let mut documents = Vec::with_capacity(total_rows);
130
131 let process_batch = |batch: &RecordBatch, documents: &mut Vec<Document>| {
132 for row_idx in 0..batch.num_rows() {
133 let schema = batch.schema();
134
135 let (content, metadata): (String, Option<Metadata>) = {
136 let mut metadata = Metadata::default();
137 let mut content = String::new();
138
139 for (col_idx, field) in schema.as_ref().fields().iter().enumerate() {
140 if let Some(array) =
141 batch.column(col_idx).as_any().downcast_ref::<StringArray>()
142 {
143 let value = array.value(row_idx).to_string();
144
145 if field.name() == "chunk" {
146 content = value;
147 } else {
148 metadata.insert(field.name().clone(), value);
149 }
150 } else {
151 }
154 }
155
156 (
157 content,
158 if metadata.is_empty() {
159 None
160 } else {
161 Some(metadata)
162 },
163 )
164 };
165
166 documents.push(Document::new(content, metadata));
167 }
168 };
169
170 for batch in batches {
171 process_batch(batch, &mut documents);
172 }
173
174 documents
175 }
176}
177
#[cfg(test)]
mod test {
    use swiftide_core::{
        Persist as _,
        indexing::{self, EmbeddedField},
    };
    use temp_dir::TempDir;

    use super::*;

    /// Builds a LanceDB store inside a fresh temporary directory.
    ///
    /// The store is configured as follows:
    /// - 384-dimensional vectors;
    /// - a `filter` metadata column;
    /// - a single `EmbeddedField::Combined` vector column.
    ///
    /// The `TempDir` is returned so the caller can keep it alive for the duration of the test.
    async fn setup() -> (TempDir, LanceDB) {
        let tempdir = TempDir::new().unwrap();
        let db_path = tempdir.child("lancedb");

        let lancedb = LanceDB::builder()
            .uri(db_path.to_str().unwrap())
            .vector_size(384)
            .with_metadata("filter")
            .with_vector(EmbeddedField::Combined)
            .table_name("swiftide_test")
            .build()
            .unwrap();
        lancedb.setup().await.unwrap();

        (tempdir, lancedb)
    }

    #[tokio::test]
    async fn test_retrieve_multiple_docs_and_filter() {
        let (_guard, lancedb) = setup().await;

        // Three chunks share an identical vector; only their `filter` metadata differs.
        let nodes = [
            ("test_query1", "true"),
            ("test_query2", "true"),
            ("test_query3", "false"),
        ]
        .into_iter()
        .map(|(chunk, filter)| {
            let mut node = indexing::TextNode::new(chunk);
            node.with_metadata(("filter", filter))
                .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]);
            node
        })
        .collect();

        lancedb
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        // Filtered searches only return rows whose metadata satisfies the predicate.
        for (predicate, expected_hits) in [("filter = \"true\"", 2), ("filter = \"banana\"", 0)] {
            let strategy = SimilaritySingleEmbedding::from_filter(predicate.to_string());
            let result = lancedb.retrieve(&strategy, query.clone()).await.unwrap();
            assert_eq!(result.documents().len(), expected_hits);
        }

        // Without a filter, every stored row matches.
        let strategy = SimilaritySingleEmbedding::<()>::default();
        let result = lancedb.retrieve(&strategy, query).await.unwrap();
        assert_eq!(result.documents().len(), 3);
    }
}