swiftide_integrations/duckdb/
retrieve.rs1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3use swiftide_core::{
4 Retrieve,
5 indexing::Chunk,
6 querying::{
7 Document, Query,
8 search_strategies::{CustomStrategy, HybridSearch, SimilaritySingleEmbedding},
9 states,
10 },
11};
12
13use super::Duckdb;
14
15#[async_trait]
16impl<T: Chunk> Retrieve<SimilaritySingleEmbedding> for Duckdb<T> {
17 async fn retrieve(
18 &self,
19 search_strategy: &SimilaritySingleEmbedding,
20 query: Query<states::Pending>,
21 ) -> Result<Query<states::Retrieved>> {
22 let Some(embedding) = query.embedding.as_ref() else {
23 return Err(anyhow::Error::msg("Missing embedding in query state"));
24 };
25
26 let table_name = &self.table_name;
27
28 let (field_name, embedding_size) = self
30 .vectors
31 .iter()
32 .next()
33 .context("No vectors configured")?;
34
35 let limit = search_strategy.top_k();
36
37 let sql = format!(
41 "SELECT uuid, chunk, path FROM {table_name}\n
42 ORDER BY array_distance({field_name}, ARRAY[{}]::FLOAT[{embedding_size}])\n
43 LIMIT {limit}",
44 embedding
45 .iter()
46 .map(ToString::to_string)
47 .collect::<Vec<_>>()
48 .join(",")
49 );
50
51 tracing::trace!("[duckdb] Executing query: {}", sql);
52
53 let conn = self.connection().lock().unwrap();
54
55 let mut stmt = conn
56 .prepare(&sql)
57 .context("Failed to prepare duckdb statement for persist")?;
58
59 tracing::trace!("[duckdb] Retrieving documents");
60
61 let documents = stmt
62 .query_map([], |row| {
63 Ok(Document::builder()
64 .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
65 .content(row.get::<_, String>(1)?)
66 .build()
67 .expect("Failed to build document; should never happen"))
68 })
69 .context("failed to query for documents")?
70 .collect::<Result<Vec<Document>, _>>()
71 .context("failed to build documents")?;
72
73 tracing::debug!("[duckdb] Retrieved documents");
74 Ok(query.retrieved_documents(documents))
75 }
76}
77
78#[async_trait]
79impl<T: Chunk> Retrieve<CustomStrategy<String>> for Duckdb<T> {
80 async fn retrieve(
81 &self,
82 search_strategy: &CustomStrategy<String>,
83 query: Query<states::Pending>,
84 ) -> Result<Query<states::Retrieved>> {
85 let sql = search_strategy
86 .build_query(&query)
87 .await
88 .context("Failed to build query")?;
89
90 tracing::debug!("[duckdb] Executing query: {}", sql);
91
92 let conn = self.connection().lock().unwrap();
93 let mut stmt = conn
94 .prepare(&sql)
95 .context("Failed to prepare duckdb statement for persist")?;
96
97 tracing::debug!("[duckdb] Prepared statement");
98
99 let documents = stmt
100 .query_map([], |row| {
101 Ok(Document::builder()
102 .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
103 .content(row.get::<_, String>(1)?)
104 .build()
105 .expect("Failed to build document; should never happen"))
106 })
107 .context("failed to query for documents")?
108 .collect::<Result<Vec<Document>, _>>()
109 .context("failed to build documents")?;
110
111 tracing::debug!("[duckdb] Retrieved documents");
112
113 Ok(query.retrieved_documents(documents))
114 }
115}
116
117#[async_trait]
118impl<T: Chunk> Retrieve<HybridSearch> for Duckdb<T> {
119 async fn retrieve(
120 &self,
121 search_strategy: &HybridSearch,
122 query: Query<states::Pending>,
123 ) -> Result<Query<states::Retrieved>> {
124 let Some(embedding) = query.embedding.as_ref() else {
125 return Err(anyhow::Error::msg("Missing embedding in query state"));
126 };
127
128 let sql = self
129 .hybrid_query_sql(search_strategy, query.current(), embedding)
130 .context("Failed to build query")?;
131
132 tracing::debug!("[duckdb] Executing query: {}", sql);
133
134 let conn = self.connection().lock().unwrap();
135 let mut stmt = conn
136 .prepare(&sql)
137 .context("Failed to prepare duckdb statement for persist")?;
138
139 tracing::debug!("[duckdb] Prepared statement");
140
141 let documents = stmt
142 .query_map([], |row| {
144 Ok(Document::builder()
145 .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
146 .content(row.get::<_, String>(1)?)
147 .build()
148 .expect("Failed to build document; should never happen"))
149 })
150 .context("failed to query for documents")?
151 .collect::<Result<Vec<Document>, _>>()
152 .context("failed to build documents")?;
153
154 tracing::debug!("[duckdb] Retrieved documents");
155
156 Ok(query.retrieved_documents(documents))
157 }
158}
159
#[cfg(test)]
mod tests {
    use indexing::{EmbeddedField, TextNode};
    use swiftide_core::{Persist as _, indexing};

    use super::*;

    /// Text stored in the database and expected back from every retrieval.
    const CONTENT: &str = "Hello duckdb!";

    /// A node holding [`CONTENT`] with a three-dimensional combined embedding.
    fn sample_node() -> TextNode {
        TextNode::new(CONTENT)
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned()
    }

    /// A pending query whose embedding equals the one stored on [`sample_node`].
    fn sample_query() -> Query<states::Pending> {
        Query::<states::Pending>::builder()
            .embedding(vec![1.0, 2.0, 3.0])
            .original("Some query")
            .build()
            .unwrap()
    }

    /// Asserts that `retrieved` holds exactly one document and that it mirrors `stored`.
    fn assert_single_match(retrieved: &Query<states::Retrieved>, stored: &TextNode) {
        assert_eq!(retrieved.documents().len(), 1);
        let hit = retrieved.documents().first().unwrap();

        assert_eq!(hit.content(), CONTENT);
        assert_eq!(
            hit.metadata().get("id").unwrap().as_str(),
            Some(stored.id().to_string().as_str())
        );
    }

    #[test_log::test(tokio::test)]
    async fn test_duckdb_retrieving_documents() {
        let stored = sample_node();

        let db = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        db.setup().await.unwrap();
        db.store(stored.clone()).await.unwrap();

        tracing::info!("Stored node");

        let strategy = SimilaritySingleEmbedding::default();
        let retrieved = db.retrieve(&strategy, sample_query()).await.unwrap();

        assert_single_match(&retrieved, &stored);
    }

    #[test_log::test(tokio::test)]
    async fn test_duckdb_retrieving_documents_hybrid() {
        let stored = sample_node();

        let db = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        db.setup().await.unwrap();
        db.store(stored.clone()).await.unwrap();

        tracing::info!("Stored node");

        let strategy = HybridSearch::default();
        let retrieved = db.retrieve(&strategy, sample_query()).await.unwrap();

        assert_single_match(&retrieved, &stored);
    }
}