swiftide_integrations/duckdb/mod.rs

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use anyhow::{Context as _, Result};
use derive_builder::Builder;
use swiftide_core::indexing::EmbeddedField;
use tera::Context;
use tokio::sync::RwLock;

pub mod node_cache;
pub mod persist;
pub mod retrieve;

const DEFAULT_INDEXING_SCHEMA: &str = include_str!("schema.sql");
const DEFAULT_UPSERT_QUERY: &str = include_str!("upsert.sql");
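/// DuckDB-backed storage for Swiftide indexing pipelines.
///
/// Holds a shared `duckdb::Connection` together with the table, vector, and
/// cache configuration used by the `node_cache`, `persist`, and `retrieve`
/// submodules. The table schema and upsert statement are rendered with Tera
/// from the bundled `schema.sql` and `upsert.sql` templates.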
#[derive(Clone, Builder)]
#[builder(setter(into))]
pub struct Duckdb {
    #[builder(setter(custom))]
    connection: Arc<Mutex<duckdb::Connection>>,

    #[builder(default = "swiftide".into())]
    table_name: String,

    #[builder(default = self.default_schema())]
    schema: String,

    #[builder(default)]
    vectors: HashMap<EmbeddedField, usize>,

    #[builder(default = "256")]
    batch_size: usize,

    #[builder(private, default = self.default_node_upsert_sql())]
    node_upsert_sql: String,

    #[builder(default = "swiftide_cache".into())]
    cache_table: String,

    #[builder(private, default = Arc::new(false.into()))]
    cache_table_created: Arc<RwLock<bool>>,

    #[builder(default = "String::new()")]
    cache_key_prefix: String,

    #[builder(default)]
    #[allow(dead_code)]
    upsert_vectors: bool,
}

impl std::fmt::Debug for Duckdb {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Duckdb")
            .field("connection", &"Arc<Mutex<duckdb::Connection>>")
            .field("table_name", &self.table_name)
            .field("batch_size", &self.batch_size)
            .finish()
    }
}

impl Duckdb {
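    /// Returns a [`DuckdbBuilder`] for configuring a `Duckdb` instance.
    ///
    /// A minimal, illustrative sketch (not taken from the crate's own docs);
    /// it assumes the `swiftide_integrations::duckdb` module path, an
    /// in-memory connection, and a single `Combined` embedding of size 384:
    ///
    /// ```no_run
    /// use swiftide_core::indexing::EmbeddedField;
    /// use swiftide_integrations::duckdb::Duckdb;
    ///
    /// # fn main() -> anyhow::Result<()> {
    /// let store = Duckdb::builder()
    ///     .connection(duckdb::Connection::open_in_memory()?)
    ///     .table_name("swiftide")
    ///     .with_vector(EmbeddedField::Combined, 384)
    ///     .build()?;
    /// # Ok(())
    /// # }
    /// ```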
    pub fn builder() -> DuckdbBuilder {
        DuckdbBuilder::default()
    }

    pub fn table_name(&self) -> &str {
        &self.table_name
    }

    pub fn cache_table(&self) -> &str {
        &self.cache_table
    }

    pub fn connection(&self) -> &Mutex<duckdb::Connection> {
        &self.connection
    }
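    /// Creates an HNSW index with the cosine metric for every configured
    /// vector column on the indexing table.
    ///
    /// HNSW indexes in DuckDB are provided by the `vss` extension, which is
    /// expected to be installed and loaded on the connection.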
    pub fn create_vector_indices(&self) -> Result<()> {
        let table_name = &self.table_name;
        let mut conn = self.connection.lock().unwrap();
        let tx = conn.transaction().context("Failed to start transaction")?;
        {
            for vector in self.vectors.keys() {
                tx.execute(
                    &format!(
                        "CREATE INDEX IF NOT EXISTS idx_{vector} ON {table_name} USING hnsw ({vector}) WITH (metric = 'cosine')",
                    ),
                    [],
                )
                .context("Could not create index")?;
            }
        }
        tx.commit().context("Failed to commit transaction")?;
        Ok(())
    }
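    /// Creates the cache table (`uuid` primary key plus `path`) and an index
    /// on `path` the first time the cache is used. The `cache_table_created`
    /// flag is kept behind an `RwLock` so subsequent calls skip the DDL.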
    pub async fn lazy_create_cache(&self) -> anyhow::Result<()> {
        if !*self.cache_table_created.read().await {
            let mut lock = self.cache_table_created.write().await;
            let conn = self.connection.lock().unwrap();
            conn.execute(
                &format!(
                    "CREATE TABLE IF NOT EXISTS {} (uuid TEXT PRIMARY KEY, path TEXT)",
                    self.cache_table
                ),
                [],
            )
            .context("Could not create table")?;

            conn.execute(
                &format!(
                    "CREATE INDEX IF NOT EXISTS idx_path ON {} (path)",
                    self.cache_table
                ),
                [],
            )
            .context("Could not create index")?;
            *lock = true;
        }
        Ok(())
    }
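    /// Builds the cache key for a node by joining the configured
    /// `cache_key_prefix` and the node id with a `.`.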
    pub fn node_key(&self, node: &swiftide_core::indexing::Node) -> String {
        format!("{}.{}", self.cache_key_prefix, node.id())
    }
}

impl DuckdbBuilder {
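    /// Sets the DuckDB connection to use, wrapping it in an `Arc<Mutex<_>>`.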
    pub fn connection(&mut self, connection: impl Into<duckdb::Connection>) -> &mut Self {
        self.connection = Some(Arc::new(Mutex::new(connection.into())));
        self
    }
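    /// Registers an embedded field and its vector size. Registered vectors are
    /// passed to the schema and upsert templates, and each gets its own index
    /// in [`Duckdb::create_vector_indices`].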
    pub fn with_vector(&mut self, field: EmbeddedField, size: usize) -> &mut Self {
        self.vectors
            .get_or_insert_with(HashMap::new)
            .insert(field, size);
        self
    }
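    /// Renders the bundled `schema.sql` Tera template with the configured
    /// table name and vectors.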
    fn default_schema(&self) -> String {
        let mut context = Context::default();
        context.insert("table_name", &self.table_name);
        context.insert("vectors", &self.vectors.clone().unwrap_or_default());

        tera::Tera::one_off(DEFAULT_INDEXING_SCHEMA, &context, false)
            .expect("Could not render schema; infallible")
    }
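    /// Renders the bundled `upsert.sql` Tera template with the table name,
    /// vectors, vector field names, and the `upsert_vectors` flag.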
    fn default_node_upsert_sql(&self) -> String {
        let mut context = Context::default();
        context.insert("table_name", &self.table_name);
        context.insert("vectors", &self.vectors.clone().unwrap_or_default());
        context.insert("upsert_vectors", &self.upsert_vectors);

        context.insert(
            "vector_field_names",
            &self
                .vectors
                .as_ref()
                .map(|v| v.keys().collect::<Vec<_>>())
                .unwrap_or_default(),
        );

        tracing::info!("Rendering upsert sql");
        tera::Tera::one_off(DEFAULT_UPSERT_QUERY, &context, false)
            .expect("could not render upsert query; infallible")
    }
}