1use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use arrow_array::types::Float32Type;
11use arrow_array::{
12 Array, FixedSizeListArray, Int64Array, RecordBatch, RecordBatchIterator, StringArray,
13};
14use arrow_schema::{DataType, Field, Schema};
15use async_trait::async_trait;
16use futures::TryStreamExt;
17use lancedb::query::{ExecutableQuery, QueryBase};
18use serde_json::Value;
19use tokio::sync::RwLock;
20
21use crate::context::types::{
22 ContextError, ContextItem, KnowledgeItem, KnowledgeSource, KnowledgeType, MemoryItem,
23 VectorBatchOperation, VectorId,
24};
25use crate::context::vector_db::VectorDatabaseStats;
26use crate::context::vector_db_trait::{DistanceMetric, VectorDb};
27use crate::types::AgentId;
28
/// Configuration for the LanceDB-backed vector store.
#[derive(Debug, Clone)]
pub struct LanceDbConfig {
    /// Filesystem directory where LanceDB keeps its data files.
    pub data_path: PathBuf,
    /// Name of the single LanceDB table (collection) used for all items.
    pub collection_name: String,
    /// Required length of every embedding; writes with a different length
    /// are rejected by `make_record_batch`.
    pub vector_dimension: usize,
    /// Distance function used for nearest-neighbour search.
    pub distance_metric: DistanceMetric,
}
41
42impl Default for LanceDbConfig {
43 fn default() -> Self {
44 Self {
45 data_path: PathBuf::from("./data/vector_db"),
46 collection_name: "symbiont_context".to_string(),
47 vector_dimension: 384,
48 distance_metric: DistanceMetric::Cosine,
49 }
50 }
51}
52
/// Vector database backend backed by an embedded LanceDB instance.
pub struct LanceDbBackend {
    /// Open LanceDB connection rooted at `config.data_path`.
    db: lancedb::Connection,
    /// Configuration captured at construction time.
    config: LanceDbConfig,
    /// Lazily-populated table handle; set by `initialize()`, read by all
    /// other operations via `get_table()`.
    table: Arc<RwLock<Option<lancedb::Table>>>,
}
58
59impl LanceDbBackend {
60 pub async fn new(config: LanceDbConfig) -> Result<Self, ContextError> {
61 std::fs::create_dir_all(&config.data_path).map_err(|e| ContextError::StorageError {
62 reason: format!(
63 "Failed to create LanceDB data dir {:?}: {}",
64 config.data_path, e
65 ),
66 })?;
67
68 let db = lancedb::connect(config.data_path.to_str().unwrap_or("./data/vector_db"))
69 .execute()
70 .await
71 .map_err(|e| ContextError::StorageError {
72 reason: format!("Failed to connect to LanceDB: {}", e),
73 })?;
74
75 Ok(Self {
76 db,
77 config,
78 table: Arc::new(RwLock::new(None)),
79 })
80 }
81
82 fn build_schema(&self) -> Arc<Schema> {
83 Arc::new(Schema::new(vec![
84 Field::new("id", DataType::Utf8, false),
85 Field::new("content", DataType::Utf8, false),
86 Field::new("agent_id", DataType::Utf8, true),
87 Field::new(
88 "vector",
89 DataType::FixedSizeList(
90 Arc::new(Field::new("item", DataType::Float32, true)),
91 self.config.vector_dimension as i32,
92 ),
93 true,
94 ),
95 Field::new("metadata_json", DataType::Utf8, true),
96 Field::new("source", DataType::Utf8, true),
97 Field::new("content_type", DataType::Utf8, true),
98 Field::new("created_at", DataType::Int64, true),
99 ]))
100 }
101
102 fn distance_type(&self) -> lancedb::DistanceType {
103 match self.config.distance_metric {
104 DistanceMetric::Cosine => lancedb::DistanceType::Cosine,
105 DistanceMetric::Euclidean => lancedb::DistanceType::L2,
106 DistanceMetric::DotProduct => lancedb::DistanceType::Dot,
107 }
108 }
109
110 async fn get_table(&self) -> Result<lancedb::Table, ContextError> {
111 let guard = self.table.read().await;
112 guard.clone().ok_or_else(|| ContextError::StorageError {
113 reason: "LanceDB table not initialized — call initialize() first".into(),
114 })
115 }
116
117 #[allow(clippy::too_many_arguments)]
118 fn make_record_batch(
119 &self,
120 schema: &Arc<Schema>,
121 id: &str,
122 content: &str,
123 agent_id: &str,
124 embedding: &[f32],
125 metadata_json: &str,
126 source: &str,
127 content_type: &str,
128 ) -> Result<RecordBatch, ContextError> {
129 if embedding.len() != self.config.vector_dimension {
130 return Err(ContextError::StorageError {
131 reason: format!(
132 "Dimension mismatch: expected {}, got {}",
133 self.config.vector_dimension,
134 embedding.len()
135 ),
136 });
137 }
138
139 let vector_array = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
142 std::iter::once(Some(embedding.iter().copied().map(Some))),
143 self.config.vector_dimension as i32,
144 );
145
146 let now_ms = std::time::SystemTime::now()
147 .duration_since(std::time::UNIX_EPOCH)
148 .unwrap_or_default()
149 .as_millis() as i64;
150
151 RecordBatch::try_new(
152 schema.clone(),
153 vec![
154 Arc::new(StringArray::from(vec![id])),
155 Arc::new(StringArray::from(vec![content])),
156 Arc::new(StringArray::from(vec![agent_id])),
157 Arc::new(vector_array),
158 Arc::new(StringArray::from(vec![metadata_json])),
159 Arc::new(StringArray::from(vec![source])),
160 Arc::new(StringArray::from(vec![content_type])),
161 Arc::new(Int64Array::from(vec![now_ms])),
162 ],
163 )
164 .map_err(|e| ContextError::StorageError {
165 reason: format!("Failed to create RecordBatch: {}", e),
166 })
167 }
168
169 fn parse_knowledge_item_from_batch(
170 &self,
171 batch: &RecordBatch,
172 row: usize,
173 ) -> Option<KnowledgeItem> {
174 let id_col = batch
175 .column_by_name("id")
176 .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
177 let content_col = batch
178 .column_by_name("content")
179 .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
180 let source_col = batch
181 .column_by_name("source")
182 .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
183 let created_col = batch
184 .column_by_name("created_at")
185 .and_then(|c| c.as_any().downcast_ref::<Int64Array>())?;
186
187 let id_str = id_col.value(row);
188 let content = content_col.value(row);
189 let source_str = source_col.value(row);
190 let created_ms = created_col.value(row);
191
192 let kid = uuid::Uuid::parse_str(id_str)
193 .ok()
194 .map(crate::context::types::KnowledgeId)
195 .unwrap_or_default();
196
197 let source = match source_str {
198 "UserProvided" => KnowledgeSource::UserProvided,
199 "Experience" => KnowledgeSource::Experience,
200 "Learning" => KnowledgeSource::Learning,
201 _ => KnowledgeSource::UserProvided,
202 };
203
204 let created_at =
205 std::time::UNIX_EPOCH + std::time::Duration::from_millis(created_ms.max(0) as u64);
206
207 Some(KnowledgeItem {
208 id: kid,
209 content: content.to_string(),
210 knowledge_type: KnowledgeType::Fact,
211 confidence: 0.9,
212 relevance_score: 0.8,
213 source,
214 created_at,
215 })
216 }
217}
218
219#[async_trait]
220impl VectorDb for LanceDbBackend {
221 async fn initialize(&self) -> Result<(), ContextError> {
222 let table_names =
223 self.db
224 .table_names()
225 .execute()
226 .await
227 .map_err(|e| ContextError::StorageError {
228 reason: format!("Failed to list LanceDB tables: {}", e),
229 })?;
230
231 let table = if table_names.contains(&self.config.collection_name) {
232 self.db
233 .open_table(&self.config.collection_name)
234 .execute()
235 .await
236 .map_err(|e| ContextError::StorageError {
237 reason: format!("Failed to open LanceDB table: {}", e),
238 })?
239 } else {
240 let schema = self.build_schema();
242 let empty_batch = RecordBatch::new_empty(schema.clone());
243 let batches = RecordBatchIterator::new(vec![Ok(empty_batch)], schema);
244
245 self.db
246 .create_table(&self.config.collection_name, Box::new(batches))
247 .execute()
248 .await
249 .map_err(|e| ContextError::StorageError {
250 reason: format!("Failed to create LanceDB table: {}", e),
251 })?
252 };
253
254 let mut guard = self.table.write().await;
255 *guard = Some(table);
256 Ok(())
257 }
258
259 async fn store_knowledge_item(
260 &self,
261 item: &KnowledgeItem,
262 embedding: Vec<f32>,
263 ) -> Result<VectorId, ContextError> {
264 let table = self.get_table().await?;
265 let schema = self.build_schema();
266 let vector_id = VectorId::new();
267
268 let metadata = serde_json::json!({
269 "knowledge_type": format!("{:?}", item.knowledge_type),
270 "confidence": item.confidence,
271 "relevance_score": item.relevance_score,
272 });
273
274 let source_str = format!("{:?}", item.source);
275
276 let batch = self.make_record_batch(
277 &schema,
278 &vector_id.to_string(),
279 &item.content,
280 "",
281 &embedding,
282 &metadata.to_string(),
283 &source_str,
284 "knowledge",
285 )?;
286
287 let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
288 table
289 .add(Box::new(batches))
290 .execute()
291 .await
292 .map_err(|e| ContextError::StorageError {
293 reason: format!("Failed to store knowledge item: {}", e),
294 })?;
295
296 Ok(vector_id)
297 }
298
299 async fn store_memory_item(
300 &self,
301 agent_id: AgentId,
302 memory: &MemoryItem,
303 embedding: Vec<f32>,
304 ) -> Result<VectorId, ContextError> {
305 let table = self.get_table().await?;
306 let schema = self.build_schema();
307 let vector_id = VectorId::new();
308
309 let metadata = serde_json::json!({
310 "memory_type": format!("{:?}", memory.memory_type),
311 "importance": memory.importance,
312 });
313
314 let batch = self.make_record_batch(
315 &schema,
316 &vector_id.to_string(),
317 &memory.content,
318 &agent_id.to_string(),
319 &embedding,
320 &metadata.to_string(),
321 "memory",
322 &format!("{:?}", memory.memory_type),
323 )?;
324
325 let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
326 table
327 .add(Box::new(batches))
328 .execute()
329 .await
330 .map_err(|e| ContextError::StorageError {
331 reason: format!("Failed to store memory item: {}", e),
332 })?;
333
334 Ok(vector_id)
335 }
336
337 async fn batch_store(
338 &self,
339 batch: VectorBatchOperation,
340 ) -> Result<Vec<VectorId>, ContextError> {
341 let mut ids = Vec::with_capacity(batch.items.len());
342 for item in &batch.items {
343 let vector_id = VectorId::new();
344 let embedding = item.embedding.clone().unwrap_or_default();
345 if embedding.is_empty() {
346 ids.push(vector_id);
347 continue;
348 }
349
350 let table = self.get_table().await?;
351 let schema = self.build_schema();
352 let metadata_json = serde_json::json!({
353 "source_id": item.metadata.source_id,
354 "tags": item.metadata.tags,
355 })
356 .to_string();
357
358 let record = self.make_record_batch(
359 &schema,
360 &vector_id.to_string(),
361 &item.content,
362 &item.metadata.agent_id.to_string(),
363 &embedding,
364 &metadata_json,
365 &item.metadata.source_id,
366 &format!("{:?}", item.metadata.content_type),
367 )?;
368
369 let batches = RecordBatchIterator::new(vec![Ok(record)], schema);
370 table.add(Box::new(batches)).execute().await.map_err(|e| {
371 ContextError::StorageError {
372 reason: format!("Failed to batch store item: {}", e),
373 }
374 })?;
375
376 ids.push(vector_id);
377 }
378 Ok(ids)
379 }
380
381 async fn search_knowledge_base(
382 &self,
383 _agent_id: AgentId,
384 query_embedding: Vec<f32>,
385 limit: usize,
386 ) -> Result<Vec<KnowledgeItem>, ContextError> {
387 let table = self.get_table().await?;
388
389 let results = table
390 .vector_search(query_embedding)
391 .map_err(|e| ContextError::StorageError {
392 reason: format!("Failed to create vector search: {}", e),
393 })?
394 .distance_type(self.distance_type())
395 .limit(limit)
396 .execute()
397 .await
398 .map_err(|e| ContextError::StorageError {
399 reason: format!("Vector search failed: {}", e),
400 })?
401 .try_collect::<Vec<_>>()
402 .await
403 .map_err(|e| ContextError::StorageError {
404 reason: format!("Failed to collect search results: {}", e),
405 })?;
406
407 let mut items = Vec::new();
408 for batch in &results {
409 for row in 0..batch.num_rows() {
410 if let Some(item) = self.parse_knowledge_item_from_batch(batch, row) {
411 items.push(item);
412 }
413 }
414 }
415
416 Ok(items)
417 }
418
419 async fn semantic_search(
420 &self,
421 agent_id: AgentId,
422 query_embedding: Vec<f32>,
423 limit: usize,
424 _threshold: f32,
425 ) -> Result<Vec<ContextItem>, ContextError> {
426 let knowledge_items = self
427 .search_knowledge_base(agent_id, query_embedding, limit)
428 .await?;
429
430 Ok(knowledge_items
431 .into_iter()
432 .map(|ki| ContextItem {
433 id: crate::context::types::ContextId::new(),
434 content: ki.content,
435 item_type: crate::context::types::ContextItemType::Knowledge(ki.knowledge_type),
436 relevance_score: ki.relevance_score,
437 timestamp: ki.created_at,
438 metadata: HashMap::new(),
439 })
440 .collect())
441 }
442
443 async fn advanced_search(
444 &self,
445 agent_id: AgentId,
446 query_embedding: Vec<f32>,
447 _filters: HashMap<String, String>,
448 limit: usize,
449 _threshold: f32,
450 ) -> Result<Vec<crate::context::types::VectorSearchResult>, ContextError> {
451 let knowledge_items = self
452 .search_knowledge_base(agent_id, query_embedding, limit)
453 .await?;
454
455 Ok(knowledge_items
456 .into_iter()
457 .map(|ki| crate::context::types::VectorSearchResult {
458 id: VectorId::new(),
459 content: ki.content,
460 score: ki.relevance_score,
461 metadata: HashMap::new(),
462 embedding: None,
463 })
464 .collect())
465 }
466
467 async fn delete_knowledge_item(&self, vector_id: VectorId) -> Result<(), ContextError> {
468 let table = self.get_table().await?;
469 table
470 .delete(&format!("id = '{}'", vector_id))
471 .await
472 .map_err(|e| ContextError::StorageError {
473 reason: format!("Failed to delete item: {}", e),
474 })?;
475 Ok(())
476 }
477
478 async fn batch_delete(&self, vector_ids: Vec<VectorId>) -> Result<(), ContextError> {
479 for id in vector_ids {
480 self.delete_knowledge_item(id).await?;
481 }
482 Ok(())
483 }
484
485 async fn update_metadata(
486 &self,
487 _vector_id: VectorId,
488 _metadata: HashMap<String, Value>,
489 ) -> Result<(), ContextError> {
490 Ok(())
492 }
493
494 async fn get_stats(&self) -> Result<VectorDatabaseStats, ContextError> {
495 let table = self.get_table().await?;
496 let count = table
497 .count_rows(None)
498 .await
499 .map_err(|e| ContextError::StorageError {
500 reason: format!("Failed to count rows: {}", e),
501 })?;
502
503 Ok(VectorDatabaseStats {
504 total_vectors: count,
505 collection_size_bytes: 0,
506 avg_query_time_ms: 0.0,
507 })
508 }
509
510 async fn create_index(&self, _field_name: &str) -> Result<(), ContextError> {
511 Ok(())
513 }
514
515 async fn optimize_collection(&self) -> Result<(), ContextError> {
516 let table = self.get_table().await?;
517 table
518 .optimize(lancedb::table::OptimizeAction::All)
519 .await
520 .map_err(|e| ContextError::StorageError {
521 reason: format!("Failed to optimize collection: {}", e),
522 })?;
523 Ok(())
524 }
525
526 async fn health_check(&self) -> Result<bool, ContextError> {
527 let result = self.db.table_names().execute().await;
528 Ok(result.is_ok())
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::types::KnowledgeId;
    use tempfile::TempDir;

    /// Builds a tiny (4-dimensional) config rooted in `dir`.
    fn make_test_config(dir: &TempDir) -> LanceDbConfig {
        LanceDbConfig {
            data_path: dir.path().to_path_buf(),
            collection_name: String::from("test_collection"),
            vector_dimension: 4,
            distance_metric: DistanceMetric::Cosine,
        }
    }

    /// Builds a knowledge item with fixed scores and the given content.
    fn make_knowledge_item(content: &str) -> KnowledgeItem {
        KnowledgeItem {
            id: KnowledgeId::new(),
            content: content.into(),
            knowledge_type: KnowledgeType::Fact,
            confidence: 0.9,
            relevance_score: 0.8,
            source: KnowledgeSource::UserProvided,
            created_at: std::time::SystemTime::now(),
        }
    }

    /// Creates and initializes a backend in a fresh temp dir. The TempDir
    /// is returned alongside so it outlives the backend.
    async fn setup() -> (TempDir, LanceDbBackend) {
        let dir = TempDir::new().unwrap();
        let store = LanceDbBackend::new(make_test_config(&dir)).await.unwrap();
        store.initialize().await.unwrap();
        (dir, store)
    }

    #[tokio::test]
    async fn test_lance_initialize_and_health() {
        let (_dir, store) = setup().await;
        assert!(store.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_lance_store_and_count() {
        let (_dir, store) = setup().await;

        let item = make_knowledge_item("Rust is a systems language");
        let id = store
            .store_knowledge_item(&item, vec![0.1, 0.2, 0.3, 0.4])
            .await
            .unwrap();
        assert_ne!(id, VectorId::default());

        assert_eq!(store.get_stats().await.unwrap().total_vectors, 1);
    }

    #[tokio::test]
    async fn test_lance_search() {
        let (_dir, store) = setup().await;

        // Two near-orthogonal embeddings so the query is unambiguous.
        for (text, vector) in [
            ("Rust is fast", vec![1.0, 0.0, 0.0, 0.0]),
            ("Python is easy", vec![0.0, 1.0, 0.0, 0.0]),
        ] {
            store
                .store_knowledge_item(&make_knowledge_item(text), vector)
                .await
                .unwrap();
        }

        let results = store
            .search_knowledge_base(AgentId::new(), vec![0.9, 0.1, 0.0, 0.0], 1)
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("Rust"));
    }

    #[tokio::test]
    async fn test_lance_delete() {
        let (_dir, store) = setup().await;

        let id = store
            .store_knowledge_item(&make_knowledge_item("Delete me"), vec![0.1, 0.2, 0.3, 0.4])
            .await
            .unwrap();

        store.delete_knowledge_item(id).await.unwrap();
        assert_eq!(store.get_stats().await.unwrap().total_vectors, 0);
    }

    #[tokio::test]
    async fn test_lance_optimize() {
        let (_dir, store) = setup().await;
        store.optimize_collection().await.unwrap();
    }
}