swiftide_integrations/lancedb/
retrieve.rs

1use anyhow::Result;
2use arrow_array::{RecordBatch, StringArray};
3use async_trait::async_trait;
4use futures_util::TryStreamExt;
5use itertools::Itertools;
6use lancedb::query::{ExecutableQuery, QueryBase};
7use swiftide_core::{
8    Retrieve,
9    document::Document,
10    indexing::Metadata,
11    querying::{
12        Query,
13        search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
14        states,
15    },
16};
17
18use super::{FieldConfig, LanceDB};
19
20/// Implement the `Retrieve` trait for `SimilaritySingleEmbedding` search strategy.
21///
22/// Can be used in the query pipeline to retrieve documents from `LanceDB`.
23///
24/// Supports filters as strings. Refer to the `LanceDB` documentation for the format.
25#[async_trait]
26impl Retrieve<SimilaritySingleEmbedding<String>> for LanceDB {
27    #[tracing::instrument]
28    async fn retrieve(
29        &self,
30        search_strategy: &SimilaritySingleEmbedding<String>,
31        query: Query<states::Pending>,
32    ) -> Result<Query<states::Retrieved>> {
33        let Some(embedding) = &query.embedding else {
34            anyhow::bail!("No embedding for query")
35        };
36
37        let table = self
38            .get_connection()
39            .await?
40            .open_table(&self.table_name)
41            .execute()
42            .await?;
43
44        let vector_fields = self
45            .fields
46            .iter()
47            .filter(|field| matches!(field, FieldConfig::Vector(_)))
48            .collect_vec();
49
50        if vector_fields.is_empty() || vector_fields.len() > 1 {
51            anyhow::bail!("Zero or multiple vector fields configured in schema")
52        }
53
54        let column_name = vector_fields.first().map(|v| v.field_name()).unwrap();
55
56        let mut query_builder = table
57            .query()
58            .nearest_to(embedding.as_slice())?
59            .column(&column_name)
60            .limit(usize::try_from(search_strategy.top_k())?);
61
62        if let Some(filter) = &search_strategy.filter() {
63            query_builder = query_builder.only_if(filter);
64        }
65
66        let batches = query_builder
67            .execute()
68            .await?
69            .try_collect::<Vec<_>>()
70            .await?;
71
72        let documents = Self::retrieve_from_record_batches(batches.as_slice());
73
74        Ok(query.retrieved_documents(documents))
75    }
76}
77
78#[async_trait]
79impl Retrieve<SimilaritySingleEmbedding> for LanceDB {
80    async fn retrieve(
81        &self,
82        search_strategy: &SimilaritySingleEmbedding,
83        query: Query<states::Pending>,
84    ) -> Result<Query<states::Retrieved>> {
85        Retrieve::<SimilaritySingleEmbedding<String>>::retrieve(
86            self,
87            &search_strategy.into_concrete_filter::<String>(),
88            query,
89        )
90        .await
91    }
92}
93
94#[async_trait]
95impl<Q: ExecutableQuery + Send + Sync + 'static> Retrieve<CustomStrategy<Q>> for LanceDB {
96    /// Implements vector similarity search for `LanceDB` using a custom query strategy.
97    ///
98    /// # Type Parameters
99    /// * `VectorQuery` - `LanceDB`'s query type for vector similarity search
100    async fn retrieve(
101        &self,
102        search_strategy: &CustomStrategy<Q>,
103        query: Query<states::Pending>,
104    ) -> Result<Query<states::Retrieved>> {
105        // Build the custom query using both strategy and query state
106        let query_builder = search_strategy.build_query(&query).await?;
107
108        // Execute the query using the builder's built-in methods
109        let batches = query_builder
110            .execute()
111            .await?
112            .try_collect::<Vec<_>>()
113            .await?;
114
115        let documents = Self::retrieve_from_record_batches(batches.as_slice());
116
117        Ok(query.retrieved_documents(documents))
118    }
119}
120
121impl LanceDB {
122    /// Retrieves documents from Arrow `RecordBatches` by processing each row and extracting content
123    /// and metadata fields.
124    ///
125    /// The function expects a "chunk" field to contain the main document content, while all other
126    /// string fields are treated as metadata. Non-string fields are currently skipped    
127    fn retrieve_from_record_batches(batches: &[RecordBatch]) -> Vec<Document> {
128        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
129        let mut documents = Vec::with_capacity(total_rows);
130
131        let process_batch = |batch: &RecordBatch, documents: &mut Vec<Document>| {
132            for row_idx in 0..batch.num_rows() {
133                let schema = batch.schema();
134
135                let (content, metadata): (String, Option<Metadata>) = {
136                    let mut metadata = Metadata::default();
137                    let mut content = String::new();
138
139                    for (col_idx, field) in schema.as_ref().fields().iter().enumerate() {
140                        if let Some(array) =
141                            batch.column(col_idx).as_any().downcast_ref::<StringArray>()
142                        {
143                            let value = array.value(row_idx).to_string();
144
145                            if field.name() == "chunk" {
146                                content = value;
147                            } else {
148                                metadata.insert(field.name().clone(), value);
149                            }
150                        } else {
151                            // Handle other array types as necessary
152                            // TODO: Can't we just downcast to serde::Value or fail?
153                        }
154                    }
155
156                    (
157                        content,
158                        if metadata.is_empty() {
159                            None
160                        } else {
161                            Some(metadata)
162                        },
163                    )
164                };
165
166                documents.push(Document::new(content, metadata));
167            }
168        };
169
170        for batch in batches {
171            process_batch(batch, &mut documents);
172        }
173
174        documents
175    }
176}
177
#[cfg(test)]
mod test {
    use swiftide_core::{
        Persist as _,
        indexing::{self, EmbeddedField},
    };
    use temp_dir::TempDir;

    use super::*;

    /// Builds a `LanceDB` store inside a fresh temporary directory and runs its setup.
    ///
    /// The store uses 384-dimensional vectors, a "filter" metadata field and the combined
    /// embedded field. The temporary directory is returned alongside the store so the caller
    /// keeps it alive for the duration of the test.
    async fn setup() -> (TempDir, LanceDB) {
        let tempdir = TempDir::new().unwrap();
        let database_path = tempdir.child("lancedb");

        let lancedb = LanceDB::builder()
            .uri(database_path.to_str().unwrap())
            .vector_size(384)
            .with_metadata("filter")
            .with_vector(EmbeddedField::Combined)
            .table_name("swiftide_test")
            .build()
            .unwrap();
        lancedb.setup().await.unwrap();

        (tempdir, lancedb)
    }

    /// Stores three nodes with identical vectors, two of them tagged `filter = "true"`, and
    /// checks how many documents come back with a matching filter, a non-matching filter and
    /// no filter at all.
    #[tokio::test]
    async fn test_retrieve_multiple_docs_and_filter() {
        let (_guard, lancedb) = setup().await;

        // (chunk, value of the "filter" metadata field)
        let nodes = [
            ("test_query1", "true"),
            ("test_query2", "true"),
            ("test_query3", "false"),
        ]
        .into_iter()
        .map(|(chunk, filter_value)| {
            let mut node = indexing::TextNode::new(chunk);
            node.with_metadata(("filter", filter_value));
            node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]);
            node
        })
        .collect();

        lancedb
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        // Two of the three stored nodes carry `filter = "true"`.
        let matching_filter =
            SimilaritySingleEmbedding::from_filter("filter = \"true\"".to_string());
        let retrieved = lancedb
            .retrieve(&matching_filter, query.clone())
            .await
            .unwrap();
        assert_eq!(retrieved.documents().len(), 2);

        // No stored node carries `filter = "banana"`.
        let non_matching_filter =
            SimilaritySingleEmbedding::from_filter("filter = \"banana\"".to_string());
        let retrieved = lancedb
            .retrieve(&non_matching_filter, query.clone())
            .await
            .unwrap();
        assert_eq!(retrieved.documents().len(), 0);

        // Without a filter every stored node is returned.
        let no_filter = SimilaritySingleEmbedding::<()>::default();
        let retrieved = lancedb
            .retrieve(&no_filter, query.clone())
            .await
            .unwrap();
        assert_eq!(retrieved.documents().len(), 3);
    }
}