swiftide_integrations/duckdb/mod.rs

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use anyhow::{Context as _, Result};
use derive_builder::Builder;
use swiftide_core::indexing::EmbeddedField;
use tera::Context;
use tokio::sync::RwLock;

pub mod node_cache;
pub mod persist;
pub mod retrieve;

const DEFAULT_INDEXING_SCHEMA: &str = include_str!("schema.sql");
const DEFAULT_UPSERT_QUERY: &str = include_str!("upsert.sql");

/// Provides `Persist`, `Retrieve`, and `NodeCache` for duckdb
///
/// Unfortunately, metadata is not stored.
///
/// NOTE: The integration is not optimized for very large datasets or heavy load. It might work;
/// if it doesn't, let us know <3.
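///
/// # Example
///
/// A minimal sketch of wiring up the store; the in-memory connection and the
/// 384-dimensional `EmbeddedField::Combined` vector are illustrative, adjust to your setup:
///
/// ```no_run
/// # use swiftide_integrations::duckdb::Duckdb;
/// # use swiftide_core::indexing::EmbeddedField;
/// let duckdb = Duckdb::builder()
///     .connection(duckdb::Connection::open_in_memory().unwrap())
///     .table_name("swiftide")
///     .with_vector(EmbeddedField::Combined, 384)
///     .build()
///     .unwrap();
/// ```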
#[derive(Clone, Builder)]
#[builder(setter(into))]
pub struct Duckdb {
    /// The connection to the database
    ///
    /// Note that this uses a standard library mutex: the duckdb connection contains a `RefCell`
    /// and is therefore not `Sync`. This is not ideal, but it is what it is.
    #[builder(setter(custom))]
    connection: Arc<Mutex<duckdb::Connection>>,

    /// The name of the table to use for storing nodes. Defaults to "swiftide".
    #[builder(default = "swiftide".into())]
    table_name: String,

    /// The schema to use for the table
    ///
    /// Note that if you change the schema, you probably also need to change the upsert query.
    ///
    /// Additionally, if you intend to use vectors, you must install and load the vss extension.
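    ///
    /// A sketch of overriding it via the builder (hypothetical DDL; keep it in sync with
    /// your upsert query):
    ///
    /// ```ignore
    /// builder.schema("CREATE TABLE IF NOT EXISTS swiftide (uuid TEXT PRIMARY KEY, chunk TEXT)");
    /// ```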
    #[builder(default = self.default_schema())]
    schema: String,

    /// The vectors to be stored, mapping field name to embedding size
    #[builder(default)]
    vectors: HashMap<EmbeddedField, usize>,

    /// Batch size for storing nodes
    #[builder(default = 256)]
    batch_size: usize,

    /// SQL used to upsert a node
    #[builder(private, default = self.default_node_upsert_sql())]
    node_upsert_sql: String,

    /// Name of the table to use for caching nodes. Defaults to `"swiftide_cache"`.
    #[builder(default = "swiftide_cache".into())]
    cache_table: String,

    /// Tracks if the cache table has been created
    #[builder(private, default = Arc::new(false.into()))]
    cache_table_created: Arc<RwLock<bool>>,

    /// Prefix to be used for keys stored in the database to avoid collisions. Can also be used
    /// to manually invalidate the cache.
    #[builder(default)]
    cache_key_prefix: String,

    /// If enabled, vectors will be upserted with an `ON CONFLICT DO UPDATE` clause. If
    /// disabled, the `ON CONFLICT` clause does nothing. Requires `duckdb` >= 1.2.1.
    #[builder(default)]
    #[allow(dead_code)]
    upsert_vectors: bool,
}

impl std::fmt::Debug for Duckdb {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Duckdb")
            .field("connection", &"Arc<Mutex<duckdb::Connection>>")
            .field("table_name", &self.table_name)
            .field("batch_size", &self.batch_size)
            .finish()
    }
}

impl Duckdb {
    pub fn builder() -> DuckdbBuilder {
        DuckdbBuilder::default()
    }

    /// Name of the indexing table
    pub fn table_name(&self) -> &str {
        &self.table_name
    }

    /// Name of the cache table
    pub fn cache_table(&self) -> &str {
        &self.cache_table
    }

    /// Returns the connection to the database
    pub fn connection(&self) -> &Mutex<duckdb::Connection> {
        &self.connection
    }

    /// Creates HNSW indices on the vector fields
    ///
    /// These are *not* persisted. You must recreate them on startup.
    ///
    /// If you want to persist them, refer to the duckdb documentation.
    ///
    /// # Errors
    ///
    /// Errors if the connection or statement fails
    ///
    /// # Panics
    ///
    /// If the mutex locking the connection is poisoned
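    ///
    /// # Example
    ///
    /// A sketch, assuming the `vss` extension is available; it must be installed and loaded
    /// on the same connection first (`execute_batch` is the plain `duckdb` crate API):
    ///
    /// ```no_run
    /// # fn example(duckdb: &swiftide_integrations::duckdb::Duckdb) -> anyhow::Result<()> {
    /// duckdb
    ///     .connection()
    ///     .lock()
    ///     .unwrap()
    ///     .execute_batch("INSTALL vss; LOAD vss;")?;
    /// duckdb.create_vector_indices()?;
    /// # Ok(())
    /// # }
    /// ```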
    pub fn create_vector_indices(&self) -> Result<()> {
        let table_name = &self.table_name;
        let mut conn = self.connection.lock().unwrap();
        let tx = conn.transaction().context("Failed to start transaction")?;

        for vector in self.vectors.keys() {
            tx.execute(
                &format!(
                    "CREATE INDEX IF NOT EXISTS idx_{vector} ON {table_name} USING hnsw ({vector}) WITH (metric = 'cosine')",
                ),
                [],
            )
            .context("Could not create index")?;
        }

        tx.commit().context("Failed to commit transaction")?;
        Ok(())
    }

    /// Safely creates the cache table if it does not exist. Can be used concurrently.
    ///
    /// # Errors
    ///
    /// Errors if the table or index could not be created
    ///
    /// # Panics
    ///
    /// If the mutex locking the connection is poisoned
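    ///
    /// # Example
    ///
    /// A sketch of concurrent use (assuming a `tokio` runtime; racing calls are fine, as the
    /// underlying DDL is idempotent):
    ///
    /// ```no_run
    /// # async fn example(duckdb: &swiftide_integrations::duckdb::Duckdb) -> anyhow::Result<()> {
    /// let (a, b) = tokio::join!(duckdb.lazy_create_cache(), duckdb.lazy_create_cache());
    /// a?;
    /// b?;
    /// # Ok(())
    /// # }
    /// ```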
    pub async fn lazy_create_cache(&self) -> anyhow::Result<()> {
        if !*self.cache_table_created.read().await {
            // Writers may race between dropping the read guard and acquiring the write
            // guard; that is fine, as the statements below are idempotent.
            let mut lock = self.cache_table_created.write().await;
            let conn = self.connection.lock().unwrap();
            conn.execute(
                &format!(
                    "CREATE TABLE IF NOT EXISTS {} (uuid TEXT PRIMARY KEY, path TEXT)",
                    self.cache_table
                ),
                [],
            )
            .context("Could not create table")?;
            // Create an extra index on path
            conn.execute(
                &format!(
                    "CREATE INDEX IF NOT EXISTS idx_path ON {} (path)",
                    self.cache_table
                ),
                [],
            )
            .context("Could not create index")?;
            *lock = true;
        }
        Ok(())
    }

    /// Formats a node key for the cache table
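    ///
    /// E.g. with `cache_key_prefix = "v1"`, a node with id `1234` yields `"v1.1234"`
    /// (illustrative values).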
    pub fn node_key(&self, node: &swiftide_core::indexing::Node) -> String {
        format!("{}.{}", self.cache_key_prefix, node.id())
    }
}

impl DuckdbBuilder {
    pub fn connection(&mut self, connection: impl Into<duckdb::Connection>) -> &mut Self {
        self.connection = Some(Arc::new(Mutex::new(connection.into())));
        self
    }

    pub fn with_vector(&mut self, field: EmbeddedField, size: usize) -> &mut Self {
        self.vectors
            .get_or_insert_with(HashMap::new)
            .insert(field, size);
        self
    }

    /// Renders the default schema template for the configured table and vectors.
    fn default_schema(&self) -> String {
        let mut context = Context::default();
        context.insert("table_name", &self.table_name);
        context.insert("vectors", &self.vectors.clone().unwrap_or_default());

        tera::Tera::one_off(DEFAULT_INDEXING_SCHEMA, &context, false)
            .expect("Could not render schema; infallible")
    }

    /// Renders the default upsert query template for the configured table and vectors.
    fn default_node_upsert_sql(&self) -> String {
        let mut context = Context::default();
        context.insert("table_name", &self.table_name);
        context.insert("vectors", &self.vectors.clone().unwrap_or_default());
        context.insert("upsert_vectors", &self.upsert_vectors);

        context.insert(
            "vector_field_names",
            &self
                .vectors
                .as_ref()
                .map(|v| v.keys().collect::<Vec<_>>())
                .unwrap_or_default(),
        );

        tracing::info!("Rendering upsert sql");
        tera::Tera::one_off(DEFAULT_UPSERT_QUERY, &context, false)
            .expect("Could not render upsert query; infallible")
    }
}