swiftide_integrations/duckdb/
retrieve.rs

1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3use swiftide_core::{
4    querying::{
5        search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
6        states, Document, Query,
7    },
8    Retrieve,
9};
10
11use super::Duckdb;
12
13#[async_trait]
14impl Retrieve<SimilaritySingleEmbedding> for Duckdb {
15    async fn retrieve(
16        &self,
17        search_strategy: &SimilaritySingleEmbedding,
18        query: Query<states::Pending>,
19    ) -> Result<Query<states::Retrieved>> {
20        let Some(embedding) = query.embedding.as_ref() else {
21            return Err(anyhow::Error::msg("Missing embedding in query state"));
22        };
23
24        let table_name = &self.table_name;
25
26        // Silently ignores multiple vector fields
27        let (field_name, embedding_size) = self
28            .vectors
29            .iter()
30            .next()
31            .context("No vectors configured")?;
32
33        let limit = search_strategy.top_k();
34
35        // Ideally it should be a prepared statement, where only the new parameters lead to extra
36        // allocations. This is possible in 1.2.1, but that version is still broken for VSS via
37        // Rust.
38        let sql = format!(
39            "SELECT uuid, chunk, path FROM {table_name}\n
40            ORDER BY array_distance({field_name}, ARRAY[{}]::FLOAT[{embedding_size}])\n
41            LIMIT {limit}",
42            embedding
43                .iter()
44                .map(ToString::to_string)
45                .collect::<Vec<_>>()
46                .join(",")
47        );
48
49        tracing::trace!("[duckdb] Executing query: {}", sql);
50
51        let conn = self.connection().lock().unwrap();
52
53        let mut stmt = conn
54            .prepare(&sql)
55            .context("Failed to prepare duckdb statement for persist")?;
56
57        tracing::trace!("[duckdb] Retrieving documents");
58
59        let documents = stmt
60            .query_map([], |row| {
61                Ok(Document::builder()
62                    .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
63                    .content(row.get::<_, String>(1)?)
64                    .build()
65                    .expect("Failed to build document; should never happen"))
66            })
67            .context("failed to query for documents")?
68            .collect::<Result<Vec<Document>, _>>()
69            .context("failed to build documents")?;
70
71        tracing::debug!("[duckdb] Retrieved documents");
72        Ok(query.retrieved_documents(documents))
73    }
74}
75
76#[async_trait]
77impl Retrieve<CustomStrategy<String>> for Duckdb {
78    async fn retrieve(
79        &self,
80        search_strategy: &CustomStrategy<String>,
81        query: Query<states::Pending>,
82    ) -> Result<Query<states::Retrieved>> {
83        let sql = search_strategy
84            .build_query(&query)
85            .await
86            .context("Failed to build query")?;
87
88        tracing::debug!("[duckdb] Executing query: {}", sql);
89
90        let conn = self.connection().lock().unwrap();
91        let mut stmt = conn
92            .prepare(&sql)
93            .context("Failed to prepare duckdb statement for persist")?;
94
95        tracing::debug!("[duckdb] Prepared statement");
96
97        let documents = stmt
98            .query_map([], |row| {
99                Ok(Document::builder()
100                    .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
101                    .content(row.get::<_, String>(1)?)
102                    .build()
103                    .expect("Failed to build document; should never happen"))
104            })
105            .context("failed to query for documents")?
106            .collect::<Result<Vec<Document>, _>>()
107            .context("failed to build documents")?;
108
109        tracing::debug!("[duckdb] Retrieved documents");
110
111        Ok(query.retrieved_documents(documents))
112    }
113}
114
#[cfg(test)]
mod tests {
    use indexing::{EmbeddedField, Node};
    use swiftide_core::{indexing, Persist as _};

    use super::*;

    /// Stores one embedded node in an in-memory database and checks that a similarity search
    /// with the same embedding returns exactly that node, with its content and id intact.
    #[test_log::test(tokio::test)]
    async fn test_duckdb_retrieving_documents() {
        // Fresh in-memory database with a single three dimensional vector column.
        let connection = duckdb::Connection::open_in_memory().unwrap();
        let db = Duckdb::builder()
            .connection(connection)
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        let stored_node = Node::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        db.setup().await.unwrap();
        db.store(stored_node.clone()).await.unwrap();

        tracing::info!("Stored node");

        // Query with the same embedding as the stored node.
        let pending_query = Query::<states::Pending>::builder()
            .embedding(vec![1.0, 2.0, 3.0])
            .original("Some query")
            .build()
            .unwrap();

        let strategy = SimilaritySingleEmbedding::default();
        let retrieved = db.retrieve(&strategy, pending_query).await.unwrap();

        let documents = retrieved.documents();
        assert_eq!(documents.len(), 1);

        let first_document = documents.first().unwrap();
        assert_eq!(first_document.content(), "Hello duckdb!");

        let expected_id = stored_node.id().to_string();
        assert_eq!(
            first_document.metadata().get("id").unwrap().as_str(),
            Some(expected_id.as_str())
        );
    }
}