Skip to main content

symbi_runtime/context/
vector_db_lance.rs

1//! LanceDB embedded vector backend.
2//!
3//! Zero-config: stores data in `./data/vector_db/` by default.
4//! No external services required — ships with the binary.
5
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use arrow_array::types::Float32Type;
11use arrow_array::{
12    Array, FixedSizeListArray, Int64Array, RecordBatch, RecordBatchIterator, StringArray,
13};
14use arrow_schema::{DataType, Field, Schema};
15use async_trait::async_trait;
16use futures::TryStreamExt;
17use lancedb::query::{ExecutableQuery, QueryBase};
18use serde_json::Value;
19use tokio::sync::RwLock;
20
21use crate::context::types::{
22    ContextError, ContextItem, KnowledgeItem, KnowledgeSource, KnowledgeType, MemoryItem,
23    VectorBatchOperation, VectorId,
24};
25use crate::context::vector_db::VectorDatabaseStats;
26use crate::context::vector_db_trait::{DistanceMetric, VectorDb};
27use crate::types::AgentId;
28
/// Configuration for the embedded LanceDB backend.
///
/// The [`Default`] impl gives a zero-config setup: data under
/// `./data/vector_db`, a single `symbiont_context` table, 384-dim
/// cosine vectors.
#[derive(Debug, Clone)]
pub struct LanceDbConfig {
    /// Path to the LanceDB data directory (created on `new()` if missing).
    pub data_path: PathBuf,
    /// Collection/table name.
    pub collection_name: String,
    /// Vector dimension; every stored embedding must match this exactly
    /// (enforced when record batches are built).
    pub vector_dimension: usize,
    /// Distance metric used for similarity search.
    pub distance_metric: DistanceMetric,
}
41
42impl Default for LanceDbConfig {
43    fn default() -> Self {
44        Self {
45            data_path: PathBuf::from("./data/vector_db"),
46            collection_name: "symbiont_context".to_string(),
47            vector_dimension: 384,
48            distance_metric: DistanceMetric::Cosine,
49        }
50    }
51}
52
/// Embedded LanceDB vector store backend (no external services required).
pub struct LanceDbBackend {
    /// Open connection to the on-disk LanceDB database.
    db: lancedb::Connection,
    /// Backend configuration (data path, table name, dimension, metric).
    config: LanceDbConfig,
    /// Lazily-initialized table handle; populated by `initialize()` and
    /// cloned out cheaply by `get_table()`.
    table: Arc<RwLock<Option<lancedb::Table>>>,
}
58
59impl LanceDbBackend {
60    pub async fn new(config: LanceDbConfig) -> Result<Self, ContextError> {
61        std::fs::create_dir_all(&config.data_path).map_err(|e| ContextError::StorageError {
62            reason: format!(
63                "Failed to create LanceDB data dir {:?}: {}",
64                config.data_path, e
65            ),
66        })?;
67
68        let db = lancedb::connect(config.data_path.to_str().unwrap_or("./data/vector_db"))
69            .execute()
70            .await
71            .map_err(|e| ContextError::StorageError {
72                reason: format!("Failed to connect to LanceDB: {}", e),
73            })?;
74
75        Ok(Self {
76            db,
77            config,
78            table: Arc::new(RwLock::new(None)),
79        })
80    }
81
82    fn build_schema(&self) -> Arc<Schema> {
83        Arc::new(Schema::new(vec![
84            Field::new("id", DataType::Utf8, false),
85            Field::new("content", DataType::Utf8, false),
86            Field::new("agent_id", DataType::Utf8, true),
87            Field::new(
88                "vector",
89                DataType::FixedSizeList(
90                    Arc::new(Field::new("item", DataType::Float32, true)),
91                    self.config.vector_dimension as i32,
92                ),
93                true,
94            ),
95            Field::new("metadata_json", DataType::Utf8, true),
96            Field::new("source", DataType::Utf8, true),
97            Field::new("content_type", DataType::Utf8, true),
98            Field::new("created_at", DataType::Int64, true),
99        ]))
100    }
101
102    fn distance_type(&self) -> lancedb::DistanceType {
103        match self.config.distance_metric {
104            DistanceMetric::Cosine => lancedb::DistanceType::Cosine,
105            DistanceMetric::Euclidean => lancedb::DistanceType::L2,
106            DistanceMetric::DotProduct => lancedb::DistanceType::Dot,
107        }
108    }
109
110    async fn get_table(&self) -> Result<lancedb::Table, ContextError> {
111        let guard = self.table.read().await;
112        guard.clone().ok_or_else(|| ContextError::StorageError {
113            reason: "LanceDB table not initialized — call initialize() first".into(),
114        })
115    }
116
117    #[allow(clippy::too_many_arguments)]
118    fn make_record_batch(
119        &self,
120        schema: &Arc<Schema>,
121        id: &str,
122        content: &str,
123        agent_id: &str,
124        embedding: &[f32],
125        metadata_json: &str,
126        source: &str,
127        content_type: &str,
128    ) -> Result<RecordBatch, ContextError> {
129        if embedding.len() != self.config.vector_dimension {
130            return Err(ContextError::StorageError {
131                reason: format!(
132                    "Dimension mismatch: expected {}, got {}",
133                    self.config.vector_dimension,
134                    embedding.len()
135                ),
136            });
137        }
138
139        // Build arrow array from embedding slice. Use from_iter_primitive
140        // with a single row containing the embedding values.
141        let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
142            std::iter::once(Some(embedding.iter().copied().map(Some))),
143            self.config.vector_dimension as i32,
144        );
145
146        let now_ms = std::time::SystemTime::now()
147            .duration_since(std::time::UNIX_EPOCH)
148            .unwrap_or_default()
149            .as_millis() as i64;
150
151        RecordBatch::try_new(
152            schema.clone(),
153            vec![
154                Arc::new(StringArray::from(vec![id])),
155                Arc::new(StringArray::from(vec![content])),
156                Arc::new(StringArray::from(vec![agent_id])),
157                Arc::new(vector_array),
158                Arc::new(StringArray::from(vec![metadata_json])),
159                Arc::new(StringArray::from(vec![source])),
160                Arc::new(StringArray::from(vec![content_type])),
161                Arc::new(Int64Array::from(vec![now_ms])),
162            ],
163        )
164        .map_err(|e| ContextError::StorageError {
165            reason: format!("Failed to create RecordBatch: {}", e),
166        })
167    }
168
169    fn parse_knowledge_item_from_batch(
170        &self,
171        batch: &RecordBatch,
172        row: usize,
173    ) -> Option<KnowledgeItem> {
174        let id_col = batch
175            .column_by_name("id")
176            .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
177        let content_col = batch
178            .column_by_name("content")
179            .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
180        let source_col = batch
181            .column_by_name("source")
182            .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
183        let created_col = batch
184            .column_by_name("created_at")
185            .and_then(|c| c.as_any().downcast_ref::<Int64Array>())?;
186
187        let id_str = id_col.value(row);
188        let content = content_col.value(row);
189        let source_str = source_col.value(row);
190        let created_ms = created_col.value(row);
191
192        let kid = uuid::Uuid::parse_str(id_str)
193            .ok()
194            .map(crate::context::types::KnowledgeId)
195            .unwrap_or_default();
196
197        let source = match source_str {
198            "UserProvided" => KnowledgeSource::UserProvided,
199            "Experience" => KnowledgeSource::Experience,
200            "Learning" => KnowledgeSource::Learning,
201            _ => KnowledgeSource::UserProvided,
202        };
203
204        let created_at =
205            std::time::UNIX_EPOCH + std::time::Duration::from_millis(created_ms.max(0) as u64);
206
207        Some(KnowledgeItem {
208            id: kid,
209            content: content.to_string(),
210            knowledge_type: KnowledgeType::Fact,
211            confidence: 0.9,
212            relevance_score: 0.8,
213            source,
214            created_at,
215        })
216    }
217}
218
219#[async_trait]
220impl VectorDb for LanceDbBackend {
221    async fn initialize(&self) -> Result<(), ContextError> {
222        let table_names =
223            self.db
224                .table_names()
225                .execute()
226                .await
227                .map_err(|e| ContextError::StorageError {
228                    reason: format!("Failed to list LanceDB tables: {}", e),
229                })?;
230
231        let table = if table_names.contains(&self.config.collection_name) {
232            self.db
233                .open_table(&self.config.collection_name)
234                .execute()
235                .await
236                .map_err(|e| ContextError::StorageError {
237                    reason: format!("Failed to open LanceDB table: {}", e),
238                })?
239        } else {
240            // Create table with an initial empty batch
241            let schema = self.build_schema();
242            let empty_batch = RecordBatch::new_empty(schema.clone());
243            let batches = RecordBatchIterator::new(vec![Ok(empty_batch)], schema);
244
245            self.db
246                .create_table(&self.config.collection_name, Box::new(batches))
247                .execute()
248                .await
249                .map_err(|e| ContextError::StorageError {
250                    reason: format!("Failed to create LanceDB table: {}", e),
251                })?
252        };
253
254        let mut guard = self.table.write().await;
255        *guard = Some(table);
256        Ok(())
257    }
258
259    async fn store_knowledge_item(
260        &self,
261        item: &KnowledgeItem,
262        embedding: Vec<f32>,
263    ) -> Result<VectorId, ContextError> {
264        let table = self.get_table().await?;
265        let schema = self.build_schema();
266        let vector_id = VectorId::new();
267
268        let metadata = serde_json::json!({
269            "knowledge_type": format!("{:?}", item.knowledge_type),
270            "confidence": item.confidence,
271            "relevance_score": item.relevance_score,
272        });
273
274        let source_str = format!("{:?}", item.source);
275
276        let batch = self.make_record_batch(
277            &schema,
278            &vector_id.to_string(),
279            &item.content,
280            "",
281            &embedding,
282            &metadata.to_string(),
283            &source_str,
284            "knowledge",
285        )?;
286
287        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
288        table
289            .add(Box::new(batches))
290            .execute()
291            .await
292            .map_err(|e| ContextError::StorageError {
293                reason: format!("Failed to store knowledge item: {}", e),
294            })?;
295
296        Ok(vector_id)
297    }
298
299    async fn store_memory_item(
300        &self,
301        agent_id: AgentId,
302        memory: &MemoryItem,
303        embedding: Vec<f32>,
304    ) -> Result<VectorId, ContextError> {
305        let table = self.get_table().await?;
306        let schema = self.build_schema();
307        let vector_id = VectorId::new();
308
309        let metadata = serde_json::json!({
310            "memory_type": format!("{:?}", memory.memory_type),
311            "importance": memory.importance,
312        });
313
314        let batch = self.make_record_batch(
315            &schema,
316            &vector_id.to_string(),
317            &memory.content,
318            &agent_id.to_string(),
319            &embedding,
320            &metadata.to_string(),
321            "memory",
322            &format!("{:?}", memory.memory_type),
323        )?;
324
325        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
326        table
327            .add(Box::new(batches))
328            .execute()
329            .await
330            .map_err(|e| ContextError::StorageError {
331                reason: format!("Failed to store memory item: {}", e),
332            })?;
333
334        Ok(vector_id)
335    }
336
337    async fn batch_store(
338        &self,
339        batch: VectorBatchOperation,
340    ) -> Result<Vec<VectorId>, ContextError> {
341        let mut ids = Vec::with_capacity(batch.items.len());
342        for item in &batch.items {
343            let vector_id = VectorId::new();
344            let embedding = item.embedding.clone().unwrap_or_default();
345            if embedding.is_empty() {
346                ids.push(vector_id);
347                continue;
348            }
349
350            let table = self.get_table().await?;
351            let schema = self.build_schema();
352            let metadata_json = serde_json::json!({
353                "source_id": item.metadata.source_id,
354                "tags": item.metadata.tags,
355            })
356            .to_string();
357
358            let record = self.make_record_batch(
359                &schema,
360                &vector_id.to_string(),
361                &item.content,
362                &item.metadata.agent_id.to_string(),
363                &embedding,
364                &metadata_json,
365                &item.metadata.source_id,
366                &format!("{:?}", item.metadata.content_type),
367            )?;
368
369            let batches = RecordBatchIterator::new(vec![Ok(record)], schema);
370            table.add(Box::new(batches)).execute().await.map_err(|e| {
371                ContextError::StorageError {
372                    reason: format!("Failed to batch store item: {}", e),
373                }
374            })?;
375
376            ids.push(vector_id);
377        }
378        Ok(ids)
379    }
380
381    async fn search_knowledge_base(
382        &self,
383        _agent_id: AgentId,
384        query_embedding: Vec<f32>,
385        limit: usize,
386    ) -> Result<Vec<KnowledgeItem>, ContextError> {
387        let table = self.get_table().await?;
388
389        let results = table
390            .vector_search(query_embedding)
391            .map_err(|e| ContextError::StorageError {
392                reason: format!("Failed to create vector search: {}", e),
393            })?
394            .distance_type(self.distance_type())
395            .limit(limit)
396            .execute()
397            .await
398            .map_err(|e| ContextError::StorageError {
399                reason: format!("Vector search failed: {}", e),
400            })?
401            .try_collect::<Vec<_>>()
402            .await
403            .map_err(|e| ContextError::StorageError {
404                reason: format!("Failed to collect search results: {}", e),
405            })?;
406
407        let mut items = Vec::new();
408        for batch in &results {
409            for row in 0..batch.num_rows() {
410                if let Some(item) = self.parse_knowledge_item_from_batch(batch, row) {
411                    items.push(item);
412                }
413            }
414        }
415
416        Ok(items)
417    }
418
419    async fn semantic_search(
420        &self,
421        agent_id: AgentId,
422        query_embedding: Vec<f32>,
423        limit: usize,
424        _threshold: f32,
425    ) -> Result<Vec<ContextItem>, ContextError> {
426        let knowledge_items = self
427            .search_knowledge_base(agent_id, query_embedding, limit)
428            .await?;
429
430        Ok(knowledge_items
431            .into_iter()
432            .map(|ki| ContextItem {
433                id: crate::context::types::ContextId::new(),
434                content: ki.content,
435                item_type: crate::context::types::ContextItemType::Knowledge(ki.knowledge_type),
436                relevance_score: ki.relevance_score,
437                timestamp: ki.created_at,
438                metadata: HashMap::new(),
439            })
440            .collect())
441    }
442
443    async fn advanced_search(
444        &self,
445        agent_id: AgentId,
446        query_embedding: Vec<f32>,
447        _filters: HashMap<String, String>,
448        limit: usize,
449        _threshold: f32,
450    ) -> Result<Vec<crate::context::types::VectorSearchResult>, ContextError> {
451        let knowledge_items = self
452            .search_knowledge_base(agent_id, query_embedding, limit)
453            .await?;
454
455        Ok(knowledge_items
456            .into_iter()
457            .map(|ki| crate::context::types::VectorSearchResult {
458                id: VectorId::new(),
459                content: ki.content,
460                score: ki.relevance_score,
461                metadata: HashMap::new(),
462                embedding: None,
463            })
464            .collect())
465    }
466
467    async fn delete_knowledge_item(&self, vector_id: VectorId) -> Result<(), ContextError> {
468        let table = self.get_table().await?;
469        table
470            .delete(&format!("id = '{}'", vector_id))
471            .await
472            .map_err(|e| ContextError::StorageError {
473                reason: format!("Failed to delete item: {}", e),
474            })?;
475        Ok(())
476    }
477
478    async fn batch_delete(&self, vector_ids: Vec<VectorId>) -> Result<(), ContextError> {
479        for id in vector_ids {
480            self.delete_knowledge_item(id).await?;
481        }
482        Ok(())
483    }
484
485    async fn update_metadata(
486        &self,
487        _vector_id: VectorId,
488        _metadata: HashMap<String, Value>,
489    ) -> Result<(), ContextError> {
490        // LanceDB doesn't have native metadata update — would need delete+reinsert
491        Ok(())
492    }
493
494    async fn get_stats(&self) -> Result<VectorDatabaseStats, ContextError> {
495        let table = self.get_table().await?;
496        let count = table
497            .count_rows(None)
498            .await
499            .map_err(|e| ContextError::StorageError {
500                reason: format!("Failed to count rows: {}", e),
501            })?;
502
503        Ok(VectorDatabaseStats {
504            total_vectors: count,
505            collection_size_bytes: 0,
506            avg_query_time_ms: 0.0,
507        })
508    }
509
510    async fn create_index(&self, _field_name: &str) -> Result<(), ContextError> {
511        // LanceDB creates indexes automatically during optimization
512        Ok(())
513    }
514
515    async fn optimize_collection(&self) -> Result<(), ContextError> {
516        let table = self.get_table().await?;
517        table
518            .optimize(lancedb::table::OptimizeAction::All)
519            .await
520            .map_err(|e| ContextError::StorageError {
521                reason: format!("Failed to optimize collection: {}", e),
522            })?;
523        Ok(())
524    }
525
526    async fn health_check(&self) -> Result<bool, ContextError> {
527        let result = self.db.table_names().execute().await;
528        Ok(result.is_ok())
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::types::KnowledgeId;
    use tempfile::TempDir;

    /// Config pointing at a throwaway temp dir with tiny 4-dim vectors.
    fn make_test_config(tmp: &TempDir) -> LanceDbConfig {
        LanceDbConfig {
            data_path: tmp.path().to_path_buf(),
            collection_name: "test_collection".to_string(),
            vector_dimension: 4,
            distance_metric: DistanceMetric::Cosine,
        }
    }

    /// Knowledge-item fixture with fixed confidence/relevance scores.
    fn make_knowledge_item(content: &str) -> KnowledgeItem {
        KnowledgeItem {
            id: KnowledgeId::new(),
            content: content.to_string(),
            knowledge_type: KnowledgeType::Fact,
            confidence: 0.9,
            relevance_score: 0.8,
            source: KnowledgeSource::UserProvided,
            created_at: std::time::SystemTime::now(),
        }
    }

    /// Builds and initializes a backend rooted at `tmp`.
    async fn fresh_backend(tmp: &TempDir) -> LanceDbBackend {
        let backend = LanceDbBackend::new(make_test_config(tmp)).await.unwrap();
        backend.initialize().await.unwrap();
        backend
    }

    #[tokio::test]
    async fn test_lance_initialize_and_health() {
        let tmp = TempDir::new().unwrap();
        let backend = fresh_backend(&tmp).await;
        assert!(backend.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_lance_store_and_count() {
        let tmp = TempDir::new().unwrap();
        let backend = fresh_backend(&tmp).await;

        let fixture = make_knowledge_item("Rust is a systems language");
        let stored_id = backend
            .store_knowledge_item(&fixture, vec![0.1, 0.2, 0.3, 0.4])
            .await
            .unwrap();
        assert_ne!(stored_id, VectorId::default());

        assert_eq!(backend.get_stats().await.unwrap().total_vectors, 1);
    }

    #[tokio::test]
    async fn test_lance_search() {
        let tmp = TempDir::new().unwrap();
        let backend = fresh_backend(&tmp).await;

        // Two orthogonal embeddings; the query leans heavily toward the first.
        for (text, emb) in [
            ("Rust is fast", vec![1.0, 0.0, 0.0, 0.0]),
            ("Python is easy", vec![0.0, 1.0, 0.0, 0.0]),
        ] {
            backend
                .store_knowledge_item(&make_knowledge_item(text), emb)
                .await
                .unwrap();
        }

        let hits = backend
            .search_knowledge_base(AgentId::new(), vec![0.9, 0.1, 0.0, 0.0], 1)
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert!(hits[0].content.contains("Rust"));
    }

    #[tokio::test]
    async fn test_lance_delete() {
        let tmp = TempDir::new().unwrap();
        let backend = fresh_backend(&tmp).await;

        let stored_id = backend
            .store_knowledge_item(&make_knowledge_item("Delete me"), vec![0.1, 0.2, 0.3, 0.4])
            .await
            .unwrap();

        backend.delete_knowledge_item(stored_id).await.unwrap();
        assert_eq!(backend.get_stats().await.unwrap().total_vectors, 0);
    }

    #[tokio::test]
    async fn test_lance_optimize() {
        let tmp = TempDir::new().unwrap();
        let backend = fresh_backend(&tmp).await;
        // An empty collection must optimize without error.
        backend.optimize_collection().await.unwrap();
    }
}