swiftide_integrations/duckdb/
retrieve.rs1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3use swiftide_core::{
4 querying::{
5 search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
6 states, Document, Query,
7 },
8 Retrieve,
9};
10
11use super::Duckdb;
12
13#[async_trait]
14impl Retrieve<SimilaritySingleEmbedding> for Duckdb {
15 async fn retrieve(
16 &self,
17 search_strategy: &SimilaritySingleEmbedding,
18 query: Query<states::Pending>,
19 ) -> Result<Query<states::Retrieved>> {
20 let Some(embedding) = query.embedding.as_ref() else {
21 return Err(anyhow::Error::msg("Missing embedding in query state"));
22 };
23
24 let table_name = &self.table_name;
25
26 let (field_name, embedding_size) = self
28 .vectors
29 .iter()
30 .next()
31 .context("No vectors configured")?;
32
33 let limit = search_strategy.top_k();
34
35 let sql = format!(
39 "SELECT uuid, chunk, path FROM {table_name}\n
40 ORDER BY array_distance({field_name}, ARRAY[{}]::FLOAT[{embedding_size}])\n
41 LIMIT {limit}",
42 embedding
43 .iter()
44 .map(ToString::to_string)
45 .collect::<Vec<_>>()
46 .join(",")
47 );
48
49 tracing::trace!("[duckdb] Executing query: {}", sql);
50
51 let conn = self.connection().lock().unwrap();
52
53 let mut stmt = conn
54 .prepare(&sql)
55 .context("Failed to prepare duckdb statement for persist")?;
56
57 tracing::trace!("[duckdb] Retrieving documents");
58
59 let documents = stmt
60 .query_map([], |row| {
61 Ok(Document::builder()
62 .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
63 .content(row.get::<_, String>(1)?)
64 .build()
65 .expect("Failed to build document; should never happen"))
66 })
67 .context("failed to query for documents")?
68 .collect::<Result<Vec<Document>, _>>()
69 .context("failed to build documents")?;
70
71 tracing::debug!("[duckdb] Retrieved documents");
72 Ok(query.retrieved_documents(documents))
73 }
74}
75
76#[async_trait]
77impl Retrieve<CustomStrategy<String>> for Duckdb {
78 async fn retrieve(
79 &self,
80 search_strategy: &CustomStrategy<String>,
81 query: Query<states::Pending>,
82 ) -> Result<Query<states::Retrieved>> {
83 let sql = search_strategy
84 .build_query(&query)
85 .await
86 .context("Failed to build query")?;
87
88 tracing::debug!("[duckdb] Executing query: {}", sql);
89
90 let conn = self.connection().lock().unwrap();
91 let mut stmt = conn
92 .prepare(&sql)
93 .context("Failed to prepare duckdb statement for persist")?;
94
95 tracing::debug!("[duckdb] Prepared statement");
96
97 let documents = stmt
98 .query_map([], |row| {
99 Ok(Document::builder()
100 .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
101 .content(row.get::<_, String>(1)?)
102 .build()
103 .expect("Failed to build document; should never happen"))
104 })
105 .context("failed to query for documents")?
106 .collect::<Result<Vec<Document>, _>>()
107 .context("failed to build documents")?;
108
109 tracing::debug!("[duckdb] Retrieved documents");
110
111 Ok(query.retrieved_documents(documents))
112 }
113}
114
#[cfg(test)]
mod tests {
    use swiftide_core::{
        indexing::{EmbeddedField, Node},
        Persist as _,
    };

    use super::*;

    /// Stores a single node in an in-memory DuckDB and checks that a similarity search with the
    /// same embedding returns exactly that node, with its content and id intact.
    #[test_log::test(tokio::test)]
    async fn test_duckdb_retrieving_documents() {
        // Node carrying a `Combined` embedding that the query below reuses verbatim.
        let stored_node = Node::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        let store = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        store.setup().await.unwrap();
        store.store(stored_node.clone()).await.unwrap();

        tracing::info!("Stored node");

        let pending_query = Query::<states::Pending>::builder()
            .embedding(vec![1.0, 2.0, 3.0])
            .original("Some query")
            .build()
            .unwrap();

        let strategy = SimilaritySingleEmbedding::default();
        let retrieved = store.retrieve(&strategy, pending_query).await.unwrap();

        let documents = retrieved.documents();
        assert_eq!(documents.len(), 1);

        let document = documents.first().unwrap();
        assert_eq!(document.content(), "Hello duckdb!");

        // The `uuid` column must round-trip into the document's `"id"` metadata.
        let expected_id = stored_node.id().to_string();
        assert_eq!(
            document.metadata().get("id").unwrap().as_str(),
            Some(expected_id.as_str())
        );
    }
}