swiftide_integrations/duckdb/
persist.rs1use std::{borrow::Cow, path::Path};
2
3use anyhow::{Context as _, Result};
4use async_trait::async_trait;
5use duckdb::{
6 Statement, ToSql, params, params_from_iter,
7 types::{ToSqlOutput, Value},
8};
9use swiftide_core::{
10 Persist,
11 indexing::{self, Chunk, Metadata, Node},
12};
13use uuid::Uuid;
14
15use super::Duckdb;
16
17#[allow(dead_code)]
18enum TextNodeValues<'a> {
19 Uuid(Uuid),
20 Path(&'a Path),
21 Chunk(&'a str),
22 Metadata(&'a Metadata),
23 Embedding(Cow<'a, [f32]>),
24 Null,
25}
26
27impl ToSql for TextNodeValues<'_> {
28 fn to_sql(&self) -> duckdb::Result<ToSqlOutput<'_>> {
29 match self {
30 TextNodeValues::Uuid(uuid) => Ok(ToSqlOutput::Owned(uuid.to_string().into())),
31 TextNodeValues::Path(path) => Ok(path.to_string_lossy().to_string().into()),
33 TextNodeValues::Chunk(chunk) => chunk.to_sql(),
34 TextNodeValues::Metadata(_metadata) => {
35 unimplemented!("maps are not yet implemented for duckdb");
36 }
38 TextNodeValues::Embedding(vector) => {
39 let array_str = format!(
40 "[{}]",
41 vector
42 .iter()
43 .map(ToString::to_string)
44 .collect::<Vec<_>>()
45 .join(",")
46 );
47 Ok(ToSqlOutput::Owned(array_str.into()))
48 }
49 TextNodeValues::Null => Ok(ToSqlOutput::Owned(Value::Null)),
50 }
51 }
52}
53
54impl<T: Chunk + AsRef<str>> Duckdb<T> {
55 fn store_node_on_stmt(&self, stmt: &mut Statement<'_>, node: &Node<T>) -> Result<()> {
56 let mut values = vec![
57 TextNodeValues::Uuid(node.id()),
58 TextNodeValues::Chunk(node.chunk.as_ref()),
59 TextNodeValues::Path(&node.path),
60 ];
61
62 let Some(node_vectors) = &node.vectors else {
63 anyhow::bail!("Expected node to have vectors; cannot store into duckdb");
64 };
65
66 for field in self.vectors.keys() {
67 let Some(vector) = node_vectors.get(field) else {
68 anyhow::bail!("Expected vector for field {} in node", field);
69 };
70
71 values.push(TextNodeValues::Embedding(vector.into()));
72 }
73
74 stmt.execute(params_from_iter(values))
76 .context("Failed to store node")?;
77
78 Ok(())
79 }
80}
81
82#[async_trait]
83impl<T: Chunk + AsRef<str>> Persist for Duckdb<T> {
84 type Input = T;
85 type Output = T;
86
87 async fn setup(&self) -> Result<()> {
88 tracing::debug!("Setting up duckdb schema");
89
90 {
91 let conn = self.connection.lock().unwrap();
92
93 if conn
95 .query_row(&format!("SHOW {}", self.table_name()), params![], |row| {
97 row.get::<_, String>(0)
98 })
99 .is_ok()
100 {
101 tracing::debug!("Indexing table already exists, skipping creation");
102 return Ok(());
103 }
104
105 let _ = conn.execute_batch(include_str!("extensions.sql"));
108
109 conn.execute_batch(&self.schema)
110 .context("Failed to create indexing table")?;
111
112 tracing::debug!(schema = &self.schema, "Indexing table created");
113 }
114
115 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
116
117 {
118 let conn = self.connection.lock().unwrap();
119 conn.execute_batch(&format!(
122 "PRAGMA create_fts_index('{}', 'uuid', 'chunk', stemmer = 'porter',
123 stopwords = 'english', ignore = '(\\.|[^a-z])+',
124 strip_accents = 1, lower = 1, overwrite = 0);
125",
126 self.table_name
127 ))?;
128 }
129
130 tracing::info!("Setup completed");
131
132 Ok(())
133 }
134
135 async fn store(&self, node: indexing::Node<T>) -> Result<indexing::Node<T>> {
136 let lock = self.connection.lock().unwrap();
137 let mut stmt = lock.prepare(&self.node_upsert_sql)?;
138 self.store_node_on_stmt(&mut stmt, &node)?;
139
140 Ok(node)
141 }
142
143 async fn batch_store(&self, nodes: Vec<indexing::Node<T>>) -> indexing::IndexingStream<T> {
144 let mut new_nodes = Vec::with_capacity(nodes.len());
146
147 tracing::debug!("Waiting for transaction");
148 let mut conn = self.connection.lock().unwrap();
149 tracing::debug!("Got transaction");
150 let tx = match conn.transaction().context("Failed to start transaction") {
151 Ok(tx) => tx,
152 Err(err) => {
153 return Err(err).into();
154 }
155 };
156
157 tracing::debug!("Starting batch store");
158 {
159 let mut stmt = match tx
160 .prepare(&self.node_upsert_sql)
161 .context("Failed to prepare statement")
162 {
163 Ok(stmt) => stmt,
164 Err(err) => {
165 return Err(err).into();
166 }
167 };
168
169 for node in nodes {
170 new_nodes.push(self.store_node_on_stmt(&mut stmt, &node).map(|()| node));
171 }
172 };
173 if let Err(err) = tx.commit().context("Failed to commit transaction") {
174 return Err(err).into();
175 }
176
177 new_nodes.into()
178 }
179}
180
#[cfg(test)]
mod tests {
    use futures_util::TryStreamExt as _;
    use indexing::{EmbeddedField, TextNode};

    use super::*;

    /// Reads every `(uuid, path, chunk)` row from the `test` table.
    ///
    /// The connection lock is taken and released inside this helper, so it is
    /// safe to call between awaited store operations.
    fn stored_rows<T: Chunk + AsRef<str>>(client: &Duckdb<T>) -> Vec<(String, String, String)> {
        let connection = client.connection.lock().unwrap();
        let mut stmt = connection
            .prepare("SELECT uuid,path,chunk FROM test")
            .unwrap();

        let rows = stmt
            .query_map([], |row| {
                Ok((
                    row.get::<_, String>(0)?,
                    row.get::<_, String>(1)?,
                    row.get::<_, String>(2)?,
                ))
            })
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        rows
    }

    /// Storing the same node repeatedly keeps one row; a node with a different
    /// chunk adds a second row.
    #[test_log::test(tokio::test)]
    async fn test_persisting_nodes() {
        let client = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        let node = TextNode::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        client.setup().await.unwrap();
        client.store(node.clone()).await.unwrap();

        tracing::info!("Stored node");

        assert_eq!(stored_rows(&client).len(), 1);

        tracing::info!("Retrieved node");
        let new_nodes = vec![node.clone(), node.clone(), node.clone()];
        let stream_nodes: Vec<TextNode> = client
            .batch_store(new_nodes)
            .await
            .try_collect()
            .await
            .unwrap();

        assert_eq!(stream_nodes.len(), 3);
        assert_eq!(stream_nodes[0], node);

        tracing::info!("Batch stored nodes 1");
        // Three copies of the same node still occupy a single row.
        assert_eq!(stored_rows(&client).len(), 1);

        let mut new_node = node.clone();
        new_node.chunk = "Something else".into();

        let new_nodes = vec![node.clone(), new_node.clone(), new_node.clone()];
        let stream = client.batch_store(new_nodes).await;

        let streamed_nodes: Vec<TextNode> = stream.try_collect().await.unwrap();
        assert_eq!(streamed_nodes.len(), 3);
        assert_eq!(streamed_nodes[0], node);

        // The node with the changed chunk shows up as a second row.
        assert_eq!(stored_rows(&client).len(), 2);
    }

    #[ignore = "json types are acting up in duckdb at the moment"]
    #[test_log::test(tokio::test)]
    async fn test_with_metadata() {
        let client = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        let mut node = TextNode::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        node.metadata
            .insert("filter".to_string(), "true".to_string());

        client.setup().await.unwrap();
        client.store(node).await.unwrap();

        tracing::info!("Stored node");

        let connection = client.connection.lock().unwrap();
        // NOTE(review): only three columns are selected, yet column index 3 is
        // read below. The query needs the metadata column added before this
        // test can be re-enabled; its name is not visible from this file.
        let mut stmt = connection
            .prepare("SELECT uuid,path,chunk FROM test")
            .unwrap();

        let node_iter = stmt
            .query_map([], |row| {
                Ok((
                    row.get::<_, String>(0).unwrap(),
                    row.get::<_, String>(1).unwrap(),
                    row.get::<_, String>(2).unwrap(),
                    row.get::<_, Value>(3).unwrap(),
                ))
            })
            .unwrap();

        let retrieved = node_iter.collect::<Result<Vec<_>, _>>().unwrap();
        dbg!(&retrieved);
        assert_eq!(retrieved.len(), 1);

        let Value::Map(metadata) = &retrieved[0].3 else {
            panic!("Expected metadata to be a map");
        };

        assert_eq!(metadata.keys().count(), 1);
        assert_eq!(
            metadata.get(&Value::Text("filter".into())).unwrap(),
            &Value::Text("true".into())
        );
    }

    /// Running `setup` a second time on the same connection must succeed.
    #[test_log::test(tokio::test)]
    async fn test_running_setup_twice() {
        let client = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        client.setup().await.unwrap();
        client.setup().await.unwrap();
    }

    /// Same as the in-memory flow, but against a database file on disk.
    #[test_log::test(tokio::test)]
    async fn test_persisted() {
        // `db_dir` must stay alive for the whole test; dropping it would
        // remove the directory holding the database file.
        let db_dir = temp_dir::TempDir::new().unwrap();
        let db_path = db_dir.path().join("test_duckdb.db");

        let client = Duckdb::builder()
            .connection(duckdb::Connection::open(db_path).unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        let mut node = TextNode::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        node.metadata
            .insert("filter".to_string(), "true".to_string());

        client.setup().await.unwrap();
        client.store(node).await.unwrap();

        tracing::info!("Stored node");

        assert_eq!(stored_rows(&client).len(), 1);
    }
}