swiftide_integrations/duckdb/
persist.rs

1use std::{borrow::Cow, path::Path};
2
3use anyhow::{Context as _, Result};
4use async_trait::async_trait;
5use duckdb::{
6    Statement, ToSql, params, params_from_iter,
7    types::{ToSqlOutput, Value},
8};
9use swiftide_core::{
10    Persist,
11    indexing::{self, Chunk, Metadata, Node},
12};
13use uuid::Uuid;
14
15use super::Duckdb;
16
17#[allow(dead_code)]
18enum TextNodeValues<'a> {
19    Uuid(Uuid),
20    Path(&'a Path),
21    Chunk(&'a str),
22    Metadata(&'a Metadata),
23    Embedding(Cow<'a, [f32]>),
24    Null,
25}
26
27impl ToSql for TextNodeValues<'_> {
28    fn to_sql(&self) -> duckdb::Result<ToSqlOutput<'_>> {
29        match self {
30            TextNodeValues::Uuid(uuid) => Ok(ToSqlOutput::Owned(uuid.to_string().into())),
31            // Should be borrow-able
32            TextNodeValues::Path(path) => Ok(path.to_string_lossy().to_string().into()),
33            TextNodeValues::Chunk(chunk) => chunk.to_sql(),
34            TextNodeValues::Metadata(_metadata) => {
35                unimplemented!("maps are not yet implemented for duckdb");
36                // Casting doesn't work either, the duckdb conversion is also not implemented :(
37            }
38            TextNodeValues::Embedding(vector) => {
39                let array_str = format!(
40                    "[{}]",
41                    vector
42                        .iter()
43                        .map(ToString::to_string)
44                        .collect::<Vec<_>>()
45                        .join(",")
46                );
47                Ok(ToSqlOutput::Owned(array_str.into()))
48            }
49            TextNodeValues::Null => Ok(ToSqlOutput::Owned(Value::Null)),
50        }
51    }
52}
53
54impl<T: Chunk + AsRef<str>> Duckdb<T> {
55    fn store_node_on_stmt(&self, stmt: &mut Statement<'_>, node: &Node<T>) -> Result<()> {
56        let mut values = vec![
57            TextNodeValues::Uuid(node.id()),
58            TextNodeValues::Chunk(node.chunk.as_ref()),
59            TextNodeValues::Path(&node.path),
60        ];
61
62        let Some(node_vectors) = &node.vectors else {
63            anyhow::bail!("Expected node to have vectors; cannot store into duckdb");
64        };
65
66        for field in self.vectors.keys() {
67            let Some(vector) = node_vectors.get(field) else {
68                anyhow::bail!("Expected vector for field {} in node", field);
69            };
70
71            values.push(TextNodeValues::Embedding(vector.into()));
72        }
73
74        // TODO: Investigate concurrency in duckdb, maybe optmistic if it works
75        stmt.execute(params_from_iter(values))
76            .context("Failed to store node")?;
77
78        Ok(())
79    }
80}
81
82#[async_trait]
83impl<T: Chunk + AsRef<str>> Persist for Duckdb<T> {
84    type Input = T;
85    type Output = T;
86
87    async fn setup(&self) -> Result<()> {
88        tracing::debug!("Setting up duckdb schema");
89
90        {
91            let conn = self.connection.lock().unwrap();
92
93            // Create if not exists does not seem to work with duckdb, so we check first
94            if conn
95                // Duckdb has issues with params it seems.
96                .query_row(&format!("SHOW {}", self.table_name()), params![], |row| {
97                    row.get::<_, String>(0)
98                })
99                .is_ok()
100            {
101                tracing::debug!("Indexing table already exists, skipping creation");
102                return Ok(());
103            }
104
105            // Install the extensions separately from the schema to avoid duckdb issues with random
106            // 'extension exists' errors
107            let _ = conn.execute_batch(include_str!("extensions.sql"));
108
109            conn.execute_batch(&self.schema)
110                .context("Failed to create indexing table")?;
111
112            tracing::debug!(schema = &self.schema, "Indexing table created");
113        }
114
115        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
116
117        {
118            let conn = self.connection.lock().unwrap();
119            // We need to run this separately to ensure the table is created before we create the
120            // index
121            conn.execute_batch(&format!(
122                "PRAGMA create_fts_index('{}', 'uuid', 'chunk', stemmer = 'porter',
123                 stopwords = 'english', ignore = '(\\.|[^a-z])+',
124                 strip_accents = 1, lower = 1, overwrite = 0);
125",
126                self.table_name
127            ))?;
128        }
129
130        tracing::info!("Setup completed");
131
132        Ok(())
133    }
134
135    async fn store(&self, node: indexing::Node<T>) -> Result<indexing::Node<T>> {
136        let lock = self.connection.lock().unwrap();
137        let mut stmt = lock.prepare(&self.node_upsert_sql)?;
138        self.store_node_on_stmt(&mut stmt, &node)?;
139
140        Ok(node)
141    }
142
143    async fn batch_store(&self, nodes: Vec<indexing::Node<T>>) -> indexing::IndexingStream<T> {
144        // TODO: Must batch
145        let mut new_nodes = Vec::with_capacity(nodes.len());
146
147        tracing::debug!("Waiting for transaction");
148        let mut conn = self.connection.lock().unwrap();
149        tracing::debug!("Got transaction");
150        let tx = match conn.transaction().context("Failed to start transaction") {
151            Ok(tx) => tx,
152            Err(err) => {
153                return Err(err).into();
154            }
155        };
156
157        tracing::debug!("Starting batch store");
158        {
159            let mut stmt = match tx
160                .prepare(&self.node_upsert_sql)
161                .context("Failed to prepare statement")
162            {
163                Ok(stmt) => stmt,
164                Err(err) => {
165                    return Err(err).into();
166                }
167            };
168
169            for node in nodes {
170                new_nodes.push(self.store_node_on_stmt(&mut stmt, &node).map(|()| node));
171            }
172        };
173        if let Err(err) = tx.commit().context("Failed to commit transaction") {
174            return Err(err).into();
175        }
176
177        new_nodes.into()
178    }
179}
180
#[cfg(test)]
mod tests {
    use futures_util::TryStreamExt as _;
    use indexing::{EmbeddedField, TextNode};

    use super::*;

    /// Returns `(uuid, path, chunk)` for every row currently stored in the `test` table.
    fn fetch_rows(connection: &duckdb::Connection) -> Vec<(String, String, String)> {
        let mut stmt = connection
            .prepare("SELECT uuid,path,chunk FROM test")
            .unwrap();
        let rows = stmt
            .query_map([], |row| {
                Ok((
                    row.get::<_, String>(0)?, // uuid
                    row.get::<_, String>(1)?, // path
                    row.get::<_, String>(2)?, // chunk
                ))
            })
            .unwrap();
        rows.collect::<Result<Vec<_>, _>>().unwrap()
    }

    #[test_log::test(tokio::test)]
    async fn test_persisting_nodes() {
        let client = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        let node = TextNode::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        client.setup().await.unwrap();
        client.store(node.clone()).await.unwrap();

        tracing::info!("Stored node");

        // The connection guard is a temporary, so it is released before the next `.await`.
        let retrieved = fetch_rows(&client.connection.lock().unwrap());
        assert_eq!(retrieved.len(), 1);

        tracing::info!("Retrieved node");

        // Verify the upsert and batch works: storing the same node repeatedly keeps one row.
        let new_nodes = vec![node.clone(), node.clone(), node.clone()];
        let stream_nodes: Vec<TextNode> = client
            .batch_store(new_nodes)
            .await
            .try_collect()
            .await
            .unwrap();

        assert_eq!(stream_nodes.len(), 3);
        assert_eq!(stream_nodes[0], node);

        tracing::info!("Batch stored nodes 1");
        let retrieved = fetch_rows(&client.connection.lock().unwrap());
        assert_eq!(retrieved.len(), 1);

        // Test batch store fully: the changed chunk yields exactly one additional row.
        let mut new_node = node.clone();
        new_node.chunk = "Something else".into();

        let new_nodes = vec![node.clone(), new_node.clone(), new_node.clone()];
        let stream = client.batch_store(new_nodes).await;

        let streamed_nodes: Vec<TextNode> = stream.try_collect().await.unwrap();
        assert_eq!(streamed_nodes.len(), 3);
        assert_eq!(streamed_nodes[0], node);

        let retrieved = fetch_rows(&client.connection.lock().unwrap());
        assert_eq!(retrieved.len(), 2);
    }

    #[ignore = "json types are acting up in duckdb at the moment"]
    #[test_log::test(tokio::test)]
    async fn test_with_metadata() {
        let client = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        let mut node = TextNode::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        node.metadata
            .insert("filter".to_string(), "true".to_string());

        client.setup().await.unwrap();
        client.store(node).await.unwrap();

        tracing::info!("Stored node");

        let connection = client.connection.lock().unwrap();
        // NOTE(review): only three columns are selected here, so reading index 3 below cannot
        // succeed. The metadata column must be added to this query once metadata is bound by
        // the upsert.
        let mut stmt = connection
            .prepare("SELECT uuid,path,chunk FROM test")
            .unwrap();

        let node_iter = stmt
            .query_map([], |row| {
                Ok((
                    row.get::<_, String>(0)?, // uuid
                    row.get::<_, String>(1)?, // path
                    row.get::<_, String>(2)?, // chunk
                    row.get::<_, Value>(3)?,  // metadata
                ))
            })
            .unwrap();

        let retrieved = node_iter.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(retrieved.len(), 1);

        let Value::Map(metadata) = &retrieved[0].3 else {
            panic!("Expected metadata to be a map");
        };

        assert_eq!(metadata.keys().count(), 1);
        assert_eq!(
            metadata.get(&Value::Text("filter".into())).unwrap(),
            &Value::Text("true".into())
        );
    }

    #[test_log::test(tokio::test)]
    async fn test_running_setup_twice() {
        let client = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        client.setup().await.unwrap();
        client.setup().await.unwrap(); // Should not panic or error
    }

    #[test_log::test(tokio::test)]
    async fn test_persisted() {
        // Keep the temp dir alive for the whole test; dropping it deletes the database file.
        let tmp_dir = temp_dir::TempDir::new().unwrap();
        let temp_db_path = tmp_dir.path().join("test_duckdb.db");

        let client = Duckdb::builder()
            .connection(duckdb::Connection::open(temp_db_path).unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        let mut node = TextNode::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        node.metadata
            .insert("filter".to_string(), "true".to_string());

        client.setup().await.unwrap();
        client.store(node).await.unwrap();

        tracing::info!("Stored node");

        let retrieved = fetch_rows(&client.connection.lock().unwrap());
        assert_eq!(retrieved.len(), 1);
    }
}