swiftide_integrations/lancedb/
mod.rs1use std::sync::Arc;
2
3use anyhow::Context as _;
4use anyhow::Result;
5use connection_pool::LanceDBConnectionPool;
6use connection_pool::LanceDBPoolManager;
7use deadpool::managed::Object;
8use derive_builder::Builder;
9use lancedb::arrow::arrow_schema::{DataType, Field, Schema};
10use swiftide_core::indexing::EmbeddedField;
11pub mod connection_pool;
12pub mod persist;
13pub mod retrieve;
14
15#[derive(Builder, Clone)]
43#[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))]
44#[allow(dead_code)]
45pub struct LanceDB {
46 #[builder(default = "self.default_connection_pool()?")]
49 connection_pool: Arc<LanceDBConnectionPool>,
50
51 uri: Option<String>,
53 #[builder(default = "Some(10)")]
55 pool_size: Option<usize>,
56
57 #[builder(default)]
59 api_key: Option<String>,
60 #[builder(default)]
62 region: Option<String>,
63 #[builder(default)]
65 storage_options: Vec<(String, String)>,
66
67 #[builder(private, default = "self.default_schema_from_fields()")]
68 schema: Arc<Schema>,
69
70 #[builder(default = "\"swiftide\".into()")]
73 table_name: String,
74
75 vector_size: Option<i32>,
78
79 #[builder(default = "256")]
81 batch_size: usize,
82
83 #[builder(default = "self.default_fields()")]
87 fields: Vec<FieldConfig>,
88}
89
90impl std::fmt::Debug for LanceDB {
91 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
92 f.debug_struct("LanceDB")
93 .field("schema", &self.schema)
94 .finish()
95 }
96}
97
98impl LanceDB {
99 pub fn builder() -> LanceDBBuilder {
100 LanceDBBuilder::default()
101 }
102
103 pub async fn get_connection(&self) -> Result<Object<LanceDBPoolManager>> {
109 Box::pin(self.connection_pool.get())
110 .await
111 .map_err(|e| anyhow::anyhow!(e))
112 }
113
114 pub async fn open_table(&self) -> Result<lancedb::Table> {
120 let conn = self.get_connection().await?;
121 conn.open_table(&self.table_name)
122 .execute()
123 .await
124 .context("Failed to open table")
125 }
126}
127
128impl LanceDBBuilder {
129 #[allow(clippy::missing_panics_doc)]
130 pub fn with_vector(&mut self, config: impl Into<VectorConfig>) -> &mut Self {
131 if self.fields.is_none() {
132 self.fields(self.default_fields());
133 }
134
135 self.fields
136 .as_mut()
137 .unwrap()
138 .push(FieldConfig::Vector(config.into()));
139
140 self
141 }
142
143 #[allow(clippy::missing_panics_doc)]
144 pub fn with_metadata(&mut self, config: impl Into<MetadataConfig>) -> &mut Self {
145 if self.fields.is_none() {
146 self.fields(self.default_fields());
147 }
148 self.fields
149 .as_mut()
150 .unwrap()
151 .push(FieldConfig::Metadata(config.into()));
152 self
153 }
154
155 #[allow(clippy::unused_self)]
156 fn default_fields(&self) -> Vec<FieldConfig> {
157 vec![FieldConfig::ID, FieldConfig::Chunk]
158 }
159
160 fn default_schema_from_fields(&self) -> Arc<Schema> {
161 let mut fields = Vec::new();
162 let vector_size = self.vector_size;
163
164 for field in self.fields.as_deref().unwrap_or(&self.default_fields()) {
165 match field {
166 FieldConfig::Vector(config) => {
167 let vector_size = config.vector_size.or(vector_size.flatten()).expect(
168 "Vector size should be set either in the field or in the LanceDB builder",
169 );
170
171 fields.push(Field::new(
172 config.field_name(),
173 DataType::FixedSizeList(
174 Arc::new(Field::new("item", DataType::Float32, true)),
175 vector_size,
176 ),
177 true,
178 ));
179 }
180 FieldConfig::Chunk => {
181 fields.push(Field::new(field.field_name(), DataType::Utf8, false));
182 }
183 FieldConfig::Metadata(_) => {
184 fields.push(Field::new(field.field_name(), DataType::Utf8, true));
185 }
186 FieldConfig::ID => {
187 fields.push(Field::new(
188 field.field_name(),
189 DataType::FixedSizeList(
190 Arc::new(Field::new("item", DataType::UInt8, true)),
191 16,
192 ),
193 false,
194 ));
195 }
196 }
197 }
198 Arc::new(Schema::new(fields))
199 }
200
201 fn default_connection_pool(&self) -> Result<Arc<LanceDBConnectionPool>> {
202 let mgr = LanceDBPoolManager::builder()
203 .uri(self.uri.clone().flatten().context("URI should be set")?)
204 .api_key(self.api_key.clone().flatten())
205 .region(self.region.clone().flatten())
206 .storage_options(self.storage_options.clone().unwrap_or_default())
207 .build()?;
208
209 LanceDBConnectionPool::builder(mgr)
210 .max_size(self.pool_size.flatten().unwrap_or(10))
211 .build()
212 .map(Arc::new)
213 .map_err(Into::into)
214 }
215}
216
217#[derive(Clone)]
218pub enum FieldConfig {
219 Vector(VectorConfig),
220 Metadata(MetadataConfig),
221 Chunk,
222 ID,
223}
224
225impl FieldConfig {
226 pub fn field_name(&self) -> String {
227 match self {
228 FieldConfig::Vector(config) => config.field_name(),
229 FieldConfig::Metadata(config) => config.field.clone(),
230 FieldConfig::Chunk => "chunk".into(),
231 FieldConfig::ID => "id".into(),
232 }
233 }
234}
235
236#[derive(Clone)]
237pub struct VectorConfig {
238 embedded_field: EmbeddedField,
239 vector_size: Option<i32>,
240}
241
242impl VectorConfig {
243 pub fn field_name(&self) -> String {
244 format!(
245 "vector_{}",
246 normalize_field_name(&self.embedded_field.to_string())
247 )
248 }
249}
250
251impl From<EmbeddedField> for VectorConfig {
252 fn from(val: EmbeddedField) -> Self {
253 VectorConfig {
254 embedded_field: val,
255 vector_size: None,
256 }
257 }
258}
259
/// Configuration of a metadata column.
#[derive(Clone)]
pub struct MetadataConfig {
    /// The metadata key exactly as supplied by the caller.
    original_field: String,
    /// The normalized column name (see `normalize_field_name`).
    field: String,
}

impl<T: AsRef<str>> From<T> for MetadataConfig {
    /// Keeps the original key and derives the column name from it.
    fn from(val: T) -> Self {
        let original = val.as_ref();
        Self {
            original_field: original.to_owned(),
            field: normalize_field_name(original),
        }
    }
}

/// Lowercases `field` and replaces every non-alphanumeric character with `_`.
pub(crate) fn normalize_field_name(field: &str) -> String {
    field
        .to_lowercase()
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}