use std::path::PathBuf;
use std::sync::Arc;

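/// Best-effort sanitization of string values interpolated into LanceDB SQL
/// filters: strips NUL bytes, statement separators, and comment tokens, and
/// escapes backslashes and single quotes. Used together with `is_valid_id`
/// as defense in depth, not as a substitute for parameterized queries.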
fn sanitize_sql_string(s: &str) -> String {
    s.replace('\0', "")
        .replace('\\', "\\\\")
        .replace('\'', "''")
        .replace(';', "")
        .replace("--", "")
        .replace("/*", "")
        .replace("*/", "")
}

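/// Returns true for identifiers that look like the IDs this crate generates:
/// non-empty, at most 64 characters, and containing only ASCII hex digits
/// and dashes.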
fn is_valid_id(s: &str) -> bool {
    !s.is_empty() && s.len() <= 64 && s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
}

use arrow_array::{
    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
    RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use chrono::{TimeZone, Utc};
use futures::TryStreamExt;
use lancedb::Table;
use lancedb::connect;
use lancedb::query::{ExecutableQuery, QueryBase};
use tracing::{debug, info};

use crate::boost_similarity;
use crate::chunker::{ChunkingConfig, chunk_content};
use crate::document::ContentType;
use crate::embedder::{EMBEDDING_DIM, Embedder};
use crate::error::{Result, SedimentError};
use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};

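/// Items longer than this many characters are split into chunks before
/// embedding (see `store_item`).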
const CHUNK_THRESHOLD: usize = 1000;

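/// Minimum similarity (computed as `1 / (1 + distance)`) for an existing item
/// to be reported as a potential conflict when new content is stored.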
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

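/// Maximum number of potential conflicts returned from a single store operation.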
const CONFLICT_SEARCH_LIMIT: usize = 5;

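/// Upper bound on chunks stored per item; longer inputs are truncated with a warning.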
const MAX_CHUNKS_PER_ITEM: usize = 200;

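/// Number of chunk texts sent to `embed_batch` at a time.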
const EMBEDDING_BATCH_SIZE: usize = 32;

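/// LanceDB-backed store that owns the connection, the embedder used to
/// vectorize content, an optional project scope, and lazily created handles
/// to the `items` and `chunks` tables.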
pub struct Database {
    db: lancedb::Connection,
    embedder: Arc<Embedder>,
    project_id: Option<String>,
    items_table: Option<Table>,
    chunks_table: Option<Table>,
}

#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    pub item_count: usize,
    pub chunk_count: usize,
}

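/// Current on-disk schema version. Version 2 removed the legacy `tags`
/// column; `check_needs_migration` detects version 1 by its presence.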
const SCHEMA_VERSION: i32 = 2;

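/// Arrow schema for the `items` table: id, content, optional project_id,
/// chunking flag, creation timestamp (seconds), and a fixed-size embedding vector.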
fn item_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("project_id", DataType::Utf8, true),
        Field::new("is_chunked", DataType::Boolean, false),
        Field::new("created_at", DataType::Int64, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

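/// Arrow schema for the `chunks` table: chunk id, parent item id, position
/// within the item, content, optional context string, and the embedding vector.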
fn chunk_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("item_id", DataType::Utf8, false),
        Field::new("chunk_index", DataType::Int32, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("context", DataType::Utf8, true),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

impl Database {
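    /// Opens (or creates) the database at `path`, constructing a default
    /// `Embedder` and applying no project scope.
    ///
    /// A minimal usage sketch, assuming a tokio runtime; the crate name
    /// `sediment` below is illustrative:
    ///
    /// ```ignore
    /// let mut db = sediment::Database::open("/path/to/db").await?;
    /// let hits = db
    ///     .search_items("how is the retry policy configured?", 5, ItemFilters::new())
    ///     .await?;
    /// for hit in hits {
    ///     println!("similarity: {:.2}", hit.similarity);
    /// }
    /// ```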
    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
        Self::open_with_project(path, None).await
    }

    pub async fn open_with_project(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
    ) -> Result<Self> {
        let embedder = Arc::new(Embedder::new()?);
        Self::open_with_embedder(path, project_id, embedder).await
    }

    pub async fn open_with_embedder(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
        embedder: Arc<Embedder>,
    ) -> Result<Self> {
        let path = path.into();
        info!("Opening database at {:?}", path);

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                SedimentError::Database(format!("Failed to create database directory: {}", e))
            })?;
        }

        let db = connect(path.to_str().ok_or_else(|| {
            SedimentError::Database("Database path contains invalid UTF-8".to_string())
        })?)
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;

        let mut database = Self {
            db,
            embedder,
            project_id,
            items_table: None,
            chunks_table: None,
        };

        database.ensure_tables().await?;
        database.ensure_vector_index().await?;

        Ok(database)
    }

    pub fn set_project_id(&mut self, project_id: Option<String>) {
        self.project_id = project_id;
    }

    pub fn project_id(&self) -> Option<&str> {
        self.project_id.as_deref()
    }

    async fn ensure_tables(&mut self) -> Result<()> {
        let mut table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

        if table_names.contains(&"items_migrated".to_string()) {
            info!("Detected interrupted migration, recovering...");
            self.recover_interrupted_migration(&table_names).await?;
            table_names =
                self.db.table_names().execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to list tables: {}", e))
                })?;
        }

        if table_names.contains(&"items".to_string()) {
            let needs_migration = self.check_needs_migration().await?;
            if needs_migration {
                info!("Migrating database schema to version {}", SCHEMA_VERSION);
                self.migrate_schema().await?;
            }
        }

        if table_names.contains(&"items".to_string()) {
            self.items_table =
                Some(self.db.open_table("items").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open items: {}", e))
                })?);
        }

        if table_names.contains(&"chunks".to_string()) {
            self.chunks_table =
                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open chunks: {}", e))
                })?);
        }

        Ok(())
    }

    async fn check_needs_migration(&self) -> Result<bool> {
        let table = self.db.open_table("items").execute().await.map_err(|e| {
            SedimentError::Database(format!("Failed to open items for check: {}", e))
        })?;

        let schema = table
            .schema()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to get schema: {}", e)))?;

        let has_tags = schema.fields().iter().any(|f| f.name() == "tags");
        Ok(has_tags)
    }

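    /// Recovers from a migration that was interrupted while an
    /// `items_migrated` staging table exists. Case A: `items` is missing, so
    /// the staging data is copied back. Case B: `items` still has the old
    /// schema, so the incomplete staging table is dropped and migration reruns.
    /// Case C: both tables are current, so the leftover staging table is
    /// simply dropped.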
    async fn recover_interrupted_migration(&mut self, table_names: &[String]) -> Result<()> {
        let has_items = table_names.contains(&"items".to_string());

        if !has_items {
            info!("Recovery case A: restoring items from items_migrated");
            let staging = self
                .db
                .open_table("items_migrated")
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to open staging table: {}", e))
                })?;

            let results = staging
                .query()
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery query failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery collect failed: {}", e)))?;

            let schema = Arc::new(item_schema());
            let new_table = self
                .db
                .create_empty_table("items", schema.clone())
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;

            if !results.is_empty() {
                let batches = RecordBatchIterator::new(results.into_iter().map(Ok), schema);
                new_table
                    .add(Box::new(batches))
                    .execute()
                    .await
                    .map_err(|e| {
                        SedimentError::Database(format!("Failed to restore items: {}", e))
                    })?;
            }

            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop staging table: {}", e))
            })?;
            info!("Recovery case A completed");
        } else {
            let has_old_schema = self.check_needs_migration().await?;

            if has_old_schema {
                info!("Recovery case B: dropping incomplete staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            } else {
                info!("Recovery case C: dropping leftover staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            }
        }

        Ok(())
    }

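    /// Migrates `items` to the current schema by converting every batch,
    /// writing it to an `items_migrated` staging table, verifying row counts,
    /// and only then swapping the staging data into a freshly created `items`
    /// table. The staging table doubles as the recovery point if this is
    /// interrupted.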
    async fn migrate_schema(&mut self) -> Result<()> {
        info!("Starting schema migration...");

        let old_table = self
            .db
            .open_table("items")
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to open old items: {}", e)))?;

        let results = old_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration collect failed: {}", e)))?;

        let mut new_batches = Vec::new();
        for batch in &results {
            let converted = self.convert_batch_to_new_schema(batch)?;
            new_batches.push(converted);
        }

        let old_count: usize = results.iter().map(|b| b.num_rows()).sum();
        let new_count: usize = new_batches.iter().map(|b| b.num_rows()).sum();
        if old_count != new_count {
            return Err(SedimentError::Database(format!(
                "Migration row count mismatch: old={}, new={}",
                old_count, new_count
            )));
        }
        info!("Migrating {} items to new schema", old_count);

        let table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
        if table_names.contains(&"items_migrated".to_string()) {
            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop stale staging: {}", e))
            })?;
        }

        let schema = Arc::new(item_schema());
        let staging_table = self
            .db
            .create_empty_table("items_migrated", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create staging table: {}", e))
            })?;

        if !new_batches.is_empty() {
            let batches = RecordBatchIterator::new(new_batches.into_iter().map(Ok), schema.clone());
            staging_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert into staging: {}", e))
                })?;
        }

        let staging_count = staging_table
            .count_rows(None)
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to count staging rows: {}", e)))?;
        if staging_count != old_count {
            let _ = self.db.drop_table("items_migrated").await;
            return Err(SedimentError::Database(format!(
                "Staging row count mismatch: expected {}, got {}",
                old_count, staging_count
            )));
        }

        self.db.drop_table("items").await.map_err(|e| {
            SedimentError::Database(format!("Failed to drop old items table: {}", e))
        })?;

        let staging_data = staging_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to read staging: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect staging: {}", e)))?;

        let new_table = self
            .db
            .create_empty_table("items", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create new items table: {}", e))
            })?;

        if !staging_data.is_empty() {
            let batches = RecordBatchIterator::new(staging_data.into_iter().map(Ok), schema);
            new_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert migrated items: {}", e))
                })?;
        }

        self.db
            .drop_table("items_migrated")
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to drop staging table: {}", e)))?;

        info!("Schema migration completed successfully");
        Ok(())
    }

    fn convert_batch_to_new_schema(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        let schema = Arc::new(item_schema());

        let id_col = batch
            .column_by_name("id")
            .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?
            .clone();

        let content_col = batch
            .column_by_name("content")
            .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?
            .clone();

        let project_id_col = batch
            .column_by_name("project_id")
            .ok_or_else(|| SedimentError::Database("Missing project_id column".to_string()))?
            .clone();

        let is_chunked_col = batch
            .column_by_name("is_chunked")
            .ok_or_else(|| SedimentError::Database("Missing is_chunked column".to_string()))?
            .clone();

        let created_at_col = batch
            .column_by_name("created_at")
            .ok_or_else(|| SedimentError::Database("Missing created_at column".to_string()))?
            .clone();

        let vector_col = batch
            .column_by_name("vector")
            .ok_or_else(|| SedimentError::Database("Missing vector column".to_string()))?
            .clone();

        RecordBatch::try_new(
            schema,
            vec![
                id_col,
                content_col,
                project_id_col,
                is_chunked_col,
                created_at_col,
                vector_col,
            ],
        )
        .map_err(|e| SedimentError::Database(format!("Failed to create migrated batch: {}", e)))
    }

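    /// Creates a vector index on the `vector` column of each table once it
    /// holds at least `MIN_ROWS_FOR_INDEX` rows; index creation failures are
    /// logged as warnings rather than propagated.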
    async fn ensure_vector_index(&self) -> Result<()> {
        const MIN_ROWS_FOR_INDEX: usize = 256;

        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
            if let Some(table) = table_opt {
                let row_count = table.count_rows(None).await.unwrap_or(0);
                if row_count < MIN_ROWS_FOR_INDEX {
                    continue;
                }

                let indices = table.list_indices().await.unwrap_or_default();

                let has_vector_index = indices
                    .iter()
                    .any(|idx| idx.columns.contains(&"vector".to_string()));

                if !has_vector_index {
                    info!(
                        "Creating vector index on {} table ({} rows)",
                        name, row_count
                    );
                    match table
                        .create_index(&["vector"], lancedb::index::Index::Auto)
                        .execute()
                        .await
                    {
                        Ok(_) => info!("Vector index created on {} table", name),
                        Err(e) => {
                            tracing::warn!("Failed to create vector index on {}: {}", name, e);
                        }
                    }
                }
            }
        }

        Ok(())
    }

    async fn get_items_table(&mut self) -> Result<&Table> {
        if self.items_table.is_none() {
            let schema = Arc::new(item_schema());
            let table = self
                .db
                .create_empty_table("items", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;
            self.items_table = Some(table);
        }
        Ok(self.items_table.as_ref().unwrap())
    }

    async fn get_chunks_table(&mut self) -> Result<&Table> {
        if self.chunks_table.is_none() {
            let schema = Arc::new(chunk_schema());
            let table = self
                .db
                .create_empty_table("chunks", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
                })?;
            self.chunks_table = Some(table);
        }
        Ok(self.chunks_table.as_ref().unwrap())
    }

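    /// Stores an item, embedding its content and, for content longer than
    /// `CHUNK_THRESHOLD` characters, writing per-chunk embeddings as well.
    /// If chunk insertion fails the item row is rolled back. The returned
    /// `StoreResult` lists existing items whose similarity exceeds
    /// `CONFLICT_SIMILARITY_THRESHOLD` as potential conflicts.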
    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
        if item.project_id.is_none() {
            item.project_id = self.project_id.clone();
        }

        let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
        item.is_chunked = should_chunk;

        let embedding_text = item.embedding_text();
        let embedding = self.embedder.embed(&embedding_text)?;
        item.embedding = embedding;

        let table = self.get_items_table().await?;
        let batch = item_to_batch(&item)?;
        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));

        table
            .add(Box::new(batches))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;

        if should_chunk {
            let content_type = detect_content_type(&item.content);
            let config = ChunkingConfig::default();
            let mut chunk_results = chunk_content(&item.content, content_type, &config);

            if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
                tracing::warn!(
                    "Chunk count {} exceeds limit {}, truncating",
                    chunk_results.len(),
                    MAX_CHUNKS_PER_ITEM
                );
                chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
            }

            if let Err(e) = self.store_chunks(&item.id, &chunk_results).await {
                let _ = self.delete_item(&item.id).await;
                return Err(e);
            }

            debug!(
                "Stored item: {} with {} chunks",
                item.id,
                chunk_results.len()
            );
        } else {
            debug!("Stored item: {} (no chunking)", item.id);
        }

        let potential_conflicts = self
            .find_similar_items_by_vector(
                &item.embedding,
                Some(&item.id),
                CONFLICT_SIMILARITY_THRESHOLD,
                CONFLICT_SEARCH_LIMIT,
            )
            .await
            .unwrap_or_default();

        Ok(StoreResult {
            id: item.id,
            potential_conflicts,
        })
    }

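    /// Embeds chunk texts in batches of `EMBEDDING_BATCH_SIZE` and inserts one
    /// row per chunk into the `chunks` table.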
    async fn store_chunks(
        &mut self,
        item_id: &str,
        chunk_results: &[crate::chunker::ChunkResult],
    ) -> Result<()> {
        let embedder = self.embedder.clone();
        let chunks_table = self.get_chunks_table().await?;

        let chunk_texts: Vec<&str> = chunk_results.iter().map(|cr| cr.content.as_str()).collect();
        let mut all_embeddings = Vec::with_capacity(chunk_texts.len());
        for batch_start in (0..chunk_texts.len()).step_by(EMBEDDING_BATCH_SIZE) {
            let batch_end = (batch_start + EMBEDDING_BATCH_SIZE).min(chunk_texts.len());
            let batch_embeddings = embedder.embed_batch(&chunk_texts[batch_start..batch_end])?;
            all_embeddings.extend(batch_embeddings);
        }

        let mut all_chunk_batches = Vec::with_capacity(chunk_results.len());
        for (i, (chunk_result, embedding)) in chunk_results.iter().zip(all_embeddings).enumerate() {
            let mut chunk = Chunk::new(item_id, i, &chunk_result.content);
            if let Some(ctx) = &chunk_result.context {
                chunk = chunk.with_context(ctx);
            }
            chunk.embedding = embedding;
            all_chunk_batches.push(chunk_to_batch(&chunk)?);
        }

        if !all_chunk_batches.is_empty() {
            let schema = Arc::new(chunk_schema());
            let batches = RecordBatchIterator::new(all_chunk_batches.into_iter().map(Ok), schema);
            chunks_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Failed to store chunks: {}", e)))?;
        }

        Ok(())
    }

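    /// Semantic search over whole items and individual chunks. Distances from
    /// LanceDB are mapped to a similarity in (0, 1] via `1 / (1 + distance)`,
    /// results below `filters.min_similarity` (default 0.3) are dropped,
    /// project-local matches are boosted via `boost_similarity`, and the best
    /// score per item wins. At most `limit` (capped at 1000) results are returned.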
    pub async fn search_items(
        &mut self,
        query: &str,
        limit: usize,
        filters: ItemFilters,
    ) -> Result<Vec<SearchResult>> {
        let limit = limit.min(1000);
        self.ensure_vector_index().await?;

        let query_embedding = self.embedder.embed(query)?;
        let min_similarity = filters.min_similarity.unwrap_or(0.3);

        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
            std::collections::HashMap::new();

        if let Some(table) = &self.items_table {
            let query_builder = table
                .vector_search(query_embedding.clone())
                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
                .limit(limit * 2);

            let results = query_builder
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect results: {}", e))
                })?;

            for batch in results {
                let items = batch_to_items(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, item) in items.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    let boosted_similarity = boost_similarity(
                        similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result = SearchResult::from_item(&item, boosted_similarity);
                    results_map
                        .entry(item.id.clone())
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        if let Some(chunks_table) = &self.chunks_table {
            let chunk_results = chunks_table
                .vector_search(query_embedding)
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to build chunk search: {}", e))
                })?
                .limit(limit * 3)
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect chunk results: {}", e))
                })?;

            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
                std::collections::HashMap::new();

            for batch in chunk_results {
                let chunks = batch_to_chunks(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    chunk_matches
                        .entry(chunk.item_id.clone())
                        .and_modify(|(content, best_sim)| {
                            if similarity > *best_sim {
                                *content = chunk.content.clone();
                                *best_sim = similarity;
                            }
                        })
                        .or_insert((chunk.content.clone(), similarity));
                }
            }

            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
                if let Some(item) = self.get_item(&item_id).await? {
                    let boosted_similarity = boost_similarity(
                        chunk_similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result =
                        SearchResult::from_item_with_excerpt(&item, boosted_similarity, excerpt);

                    results_map
                        .entry(item_id)
                        .and_modify(|(existing, existing_sim)| {
                            if boosted_similarity > *existing_sim {
                                *existing = result.clone();
                                *existing_sim = boosted_similarity;
                            }
                        })
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        let mut search_results: Vec<SearchResult> =
            results_map.into_values().map(|(r, _)| r).collect();
        search_results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        search_results.truncate(limit);

        Ok(search_results)
    }

    pub async fn find_similar_items(
        &mut self,
        content: &str,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let embedding = self.embedder.embed(content)?;
        self.find_similar_items_by_vector(&embedding, None, min_similarity, limit)
            .await
    }

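    /// Vector-only lookup used both by `find_similar_items` and by conflict
    /// detection in `store_item`; `exclude_id` filters out the item that was
    /// just stored.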
    pub async fn find_similar_items_by_vector(
        &self,
        embedding: &[f32],
        exclude_id: Option<&str>,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let results = table
            .vector_search(embedding.to_vec())
            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
            .limit(limit)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;

        let mut conflicts = Vec::new();

        for batch in results {
            let items = batch_to_items(&batch)?;
            let distances = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

            for (i, item) in items.into_iter().enumerate() {
                if exclude_id.is_some_and(|eid| eid == item.id) {
                    continue;
                }

                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                let similarity = 1.0 / (1.0 + distance);

                if similarity >= min_similarity {
                    conflicts.push(ConflictInfo {
                        id: item.id,
                        content: item.content,
                        similarity,
                    });
                }
            }
        }

        conflicts.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(conflicts)
    }

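    /// Lists items by scope: `Project` filters to the current project (and
    /// returns nothing when no project is set), `Global` returns items with no
    /// project, and `All` applies no project filter.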
    pub async fn list_items(
        &mut self,
        _filters: ItemFilters,
        limit: Option<usize>,
        scope: crate::ListScope,
    ) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let mut filter_parts = Vec::new();

        match scope {
            crate::ListScope::Project => {
                if let Some(ref pid) = self.project_id {
                    if !is_valid_id(pid) {
                        return Err(SedimentError::Database(
                            "Invalid project_id for list filter".to_string(),
                        ));
                    }
                    filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
                } else {
                    return Ok(Vec::new());
                }
            }
            crate::ListScope::Global => {
                filter_parts.push("project_id IS NULL".to_string());
            }
            crate::ListScope::All => {}
        }

        let mut query = table.query();

        if !filter_parts.is_empty() {
            let filter_str = filter_parts.join(" AND ");
            query = query.only_if(filter_str);
        }

        if let Some(l) = limit {
            query = query.limit(l);
        }

        let results = query
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
        if !is_valid_id(id) {
            return Ok(None);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(None),
        };

        let results = table
            .query()
            .only_if(format!("id = '{}'", sanitize_sql_string(id)))
            .limit(1)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        for batch in results {
            let items = batch_to_items(&batch)?;
            if let Some(item) = items.into_iter().next() {
                return Ok(Some(item));
            }
        }

        Ok(None)
    }

    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let quoted: Vec<String> = ids
            .iter()
            .filter(|id| is_valid_id(id))
            .map(|id| format!("'{}'", sanitize_sql_string(id)))
            .collect();
        if quoted.is_empty() {
            return Ok(Vec::new());
        }
        let filter = format!("id IN ({})", quoted.join(", "));

        let results = table
            .query()
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

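    /// Deletes an item and its chunks. Returns `Ok(false)` if the id is
    /// invalid or no such item exists.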
    pub async fn delete_item(&self, id: &str) -> Result<bool> {
        if !is_valid_id(id) {
            return Ok(false);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(false),
        };

        let exists = self.get_item(id).await?.is_some();
        if !exists {
            return Ok(false);
        }

        if let Some(chunks_table) = &self.chunks_table {
            chunks_table
                .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
                .await
                .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
        }

        table
            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;

        Ok(true)
    }

    pub async fn stats(&self) -> Result<DatabaseStats> {
        let mut stats = DatabaseStats::default();

        if let Some(table) = &self.items_table {
            stats.item_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        if let Some(table) = &self.chunks_table {
            stats.chunk_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        Ok(stats)
    }
}

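/// Rewrites `project_id` values on the `items` table from `old_id` to
/// `new_id`, returning the number of rows updated. Opens its own connection
/// rather than going through `Database`.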
pub async fn migrate_project_id(
    db_path: &std::path::Path,
    old_id: &str,
    new_id: &str,
) -> Result<u64> {
    if !is_valid_id(old_id) || !is_valid_id(new_id) {
        return Err(SedimentError::Database(
            "Invalid project ID for migration".to_string(),
        ));
    }

    let db = connect(db_path.to_str().ok_or_else(|| {
        SedimentError::Database("Database path contains invalid UTF-8".to_string())
    })?)
    .execute()
    .await
    .map_err(|e| SedimentError::Database(format!("Failed to connect for migration: {}", e)))?;

    let table_names = db
        .table_names()
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

    let mut total_updated = 0u64;

    if table_names.contains(&"items".to_string()) {
        let table =
            db.open_table("items").execute().await.map_err(|e| {
                SedimentError::Database(format!("Failed to open items table: {}", e))
            })?;

        let updated = table
            .update()
            .only_if(format!("project_id = '{}'", sanitize_sql_string(old_id)))
            .column("project_id", format!("'{}'", sanitize_sql_string(new_id)))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to migrate items: {}", e)))?;

        total_updated += updated;
        info!(
            "Migrated {} items from project {} to {}",
            updated, old_id, new_id
        );
    }

    Ok(total_updated)
}

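/// Combines a raw similarity score with recency and usage signals:
/// `score = similarity * freshness * frequency`, where
/// `freshness = 1 / (1 + age_days / 30)` (age measured from
/// `last_accessed_at`, falling back to `created_at`) and
/// `frequency = 1 + 0.1 * ln(1 + access_count)`. Non-finite inputs and
/// results collapse to 0.0.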
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    if !similarity.is_finite() {
        return 0.0;
    }

    let reference_time = last_accessed_at.unwrap_or(created_at);
    let age_secs = (now - reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    let result = similarity * (freshness * frequency) as f32;
    if result.is_finite() { result } else { 0.0 }
}

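/// Heuristically classifies content as JSON, YAML, Markdown, code, or plain
/// text. The checks run from most to least specific; prose that merely
/// contains colons is not treated as YAML (see the tests below).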
fn detect_content_type(content: &str) -> ContentType {
    let trimmed = content.trim();

    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
    {
        return ContentType::Json;
    }

    if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
        let lines: Vec<&str> = trimmed.lines().take(10).collect();
        let yaml_key_count = lines
            .iter()
            .filter(|line| {
                let l = line.trim();
                !l.is_empty()
                    && !l.starts_with('#')
                    && !l.contains("://")
                    && l.contains(": ")
                    && l.split(": ").next().is_some_and(|key| {
                        let k = key.trim_start_matches("- ");
                        !k.is_empty()
                            && k.chars()
                                .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
                    })
            })
            .count();
        if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
            return ContentType::Yaml;
        }
    }

    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
        return ContentType::Markdown;
    }

    let code_patterns = [
        "fn ",
        "pub fn ",
        "def ",
        "class ",
        "function ",
        "const ",
        "let ",
        "var ",
        "import ",
        "export ",
        "struct ",
        "impl ",
        "trait ",
    ];
    let has_code_pattern = trimmed.lines().any(|line| {
        let l = line.trim();
        code_patterns.iter().any(|p| l.starts_with(p))
    });
    if has_code_pattern {
        return ContentType::Code;
    }

    ContentType::Text
}

fn item_to_batch(item: &Item) -> Result<RecordBatch> {
    let schema = Arc::new(item_schema());

    let id = StringArray::from(vec![item.id.as_str()]);
    let content = StringArray::from(vec![item.content.as_str()]);
    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);

    let vector = create_embedding_array(&item.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(content),
            Arc::new(project_id),
            Arc::new(is_chunked),
            Arc::new(created_at),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
    let mut items = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let project_id_col = batch
        .column_by_name("project_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let is_chunked_col = batch
        .column_by_name("is_chunked")
        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());

    let created_at_col = batch
        .column_by_name("created_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    let vector_col = batch
        .column_by_name("vector")
        .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let content = content_col.value(i).to_string();

        let project_id = project_id_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);

        let created_at = created_at_col
            .map(|c| {
                Utc.timestamp_opt(c.value(i), 0)
                    .single()
                    .unwrap_or_else(Utc::now)
            })
            .unwrap_or_else(Utc::now);

        let embedding = vector_col
            .and_then(|col| {
                let value = col.value(i);
                value
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .map(|arr| arr.values().to_vec())
            })
            .unwrap_or_default();

        let item = Item {
            id,
            content,
            embedding,
            project_id,
            is_chunked,
            created_at,
        };

        items.push(item);
    }

    Ok(items)
}

fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
    let schema = Arc::new(chunk_schema());

    let id = StringArray::from(vec![chunk.id.as_str()]);
    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
    let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
    let content = StringArray::from(vec![chunk.content.as_str()]);
    let context = StringArray::from(vec![chunk.context.as_deref()]);

    let vector = create_embedding_array(&chunk.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(item_id),
            Arc::new(chunk_index),
            Arc::new(content),
            Arc::new(context),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
    let mut chunks = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let item_id_col = batch
        .column_by_name("item_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;

    let chunk_index_col = batch
        .column_by_name("chunk_index")
        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let context_col = batch
        .column_by_name("context")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let item_id = item_id_col.value(i).to_string();
        let chunk_index = chunk_index_col.value(i) as usize;
        let content = content_col.value(i).to_string();
        let context = context_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let chunk = Chunk {
            id,
            item_id,
            chunk_index,
            content,
            embedding: Vec::new(),
            context,
        };

        chunks.push(chunk);
    }

    Ok(chunks)
}

fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
    let values = Float32Array::from(embedding.to_vec());
    let field = Arc::new(Field::new("item", DataType::Float32, true));

    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_with_decay_fresh_item() {
        let now = 1700000000i64;
        let created = now;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 1.0 * 1.0;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.5;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let last_accessed = now;
        let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        let expected = 0.8 * 1.0 * freq as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        let now = 1700000000i64;
        let created = now - 90 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.25;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
        assert_eq!(sanitize_sql_string("hello"), "hello");
        assert_eq!(sanitize_sql_string("it's"), "it''s");
        assert_eq!(sanitize_sql_string(r"a\'b"), r"a\\''b");
        assert_eq!(sanitize_sql_string(r"path\to\file"), r"path\\to\\file");
    }

    #[test]
    fn test_sanitize_sql_string_strips_null_bytes() {
        assert_eq!(sanitize_sql_string("abc\0def"), "abcdef");
        assert_eq!(sanitize_sql_string("\0' OR 1=1 --"), "'' OR 1=1 ");
        assert_eq!(sanitize_sql_string("*/ OR 1=1"), " OR 1=1");
        assert_eq!(sanitize_sql_string("clean"), "clean");
    }

    #[test]
    fn test_sanitize_sql_string_strips_semicolons() {
        assert_eq!(
            sanitize_sql_string("a; DROP TABLE items"),
            "a DROP TABLE items"
        );
        assert_eq!(sanitize_sql_string("normal;"), "normal");
    }

    #[test]
    fn test_sanitize_sql_string_strips_comments() {
        // Removing "--", "/*", and "*/" leaves the surrounding spaces behind,
        // so two adjacent spaces remain where a token was stripped.
        assert_eq!(sanitize_sql_string("val' -- comment"), "val''  comment");
        assert_eq!(sanitize_sql_string("val' /* block */"), "val''  block ");
        assert_eq!(sanitize_sql_string("a--b--c"), "abc");
        assert_eq!(sanitize_sql_string("injected */ rest"), "injected  rest");
        assert_eq!(sanitize_sql_string("*/"), "");
    }

    #[test]
    fn test_sanitize_sql_string_adversarial_inputs() {
        assert_eq!(
            sanitize_sql_string("'; DROP TABLE items;--"),
            "'' DROP TABLE items"
        );
        assert_eq!(
            sanitize_sql_string("hello\u{200B}world"),
            "hello\u{200B}world"
        );
        assert_eq!(sanitize_sql_string(""), "");
        assert_eq!(sanitize_sql_string("\0;\0"), "");
    }

    #[test]
    fn test_is_valid_id() {
        assert!(is_valid_id("550e8400-e29b-41d4-a716-446655440000"));
        assert!(is_valid_id("abcdef0123456789"));
        assert!(!is_valid_id(""));
        assert!(!is_valid_id("'; DROP TABLE items;--"));
        assert!(!is_valid_id("hello world"));
        assert!(!is_valid_id("abc\0def"));
        assert!(!is_valid_id(&"a".repeat(65)));
    }

    #[test]
    fn test_detect_content_type_yaml_not_prose() {
        let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
        let detected = detect_content_type(prose);
        assert_ne!(
            detected,
            ContentType::Yaml,
            "Prose with colons should not be detected as YAML"
        );

        let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_detect_content_type_yaml_with_separator() {
        let yaml = "---\nname: test\nversion: 1.0";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_chunk_threshold_uses_chars_not_bytes() {
        let emoji_content = "😀".repeat(500);
        assert_eq!(emoji_content.chars().count(), 500);
        assert_eq!(emoji_content.len(), 2000);
        let should_chunk = emoji_content.chars().count() > CHUNK_THRESHOLD;
        assert!(
            !should_chunk,
            "500 chars should not exceed 1000-char threshold"
        );

        let long_content = "a".repeat(1001);
        let should_chunk = long_content.chars().count() > CHUNK_THRESHOLD;
        assert!(should_chunk, "1001 chars should exceed 1000-char threshold");
    }

    #[test]
    fn test_schema_version() {
        let version = SCHEMA_VERSION;
        assert!(version >= 2, "Schema version should be at least 2");
    }

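    /// Legacy (schema version 1) layout of the `items` table, including the
    /// removed `tags` column, used to exercise the migration and recovery tests.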
    fn old_item_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
            Field::new("project_id", DataType::Utf8, true),
            Field::new("tags", DataType::Utf8, true),
            Field::new("is_chunked", DataType::Boolean, false),
            Field::new("created_at", DataType::Int64, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    EMBEDDING_DIM as i32,
                ),
                false,
            ),
        ])
    }

    fn old_item_batch(id: &str, content: &str) -> RecordBatch {
        let schema = Arc::new(old_item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![id])),
                Arc::new(StringArray::from(vec![content])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap()
    }

    #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_detects_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch = old_item_batch("test-id-1", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            needs_migration,
            "Old schema with tags column should need migration"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_false_for_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema());
        db_conn
            .create_empty_table("items", schema)
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "New schema should not need migration");
    }

    #[tokio::test]
    #[ignore]
    async fn test_migrate_schema_preserves_data() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch1 = old_item_batch("id-aaa", "first item content");
        let batch2 = old_item_batch("id-bbb", "second item content");
        let batches = RecordBatchIterator::new(vec![Ok(batch1), Ok(batch2)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            !needs_migration,
            "Schema should be migrated (no tags column)"
        );

        let item_a = db.get_item("id-aaa").await.unwrap();
        assert!(item_a.is_some(), "Item id-aaa should be preserved");
        assert_eq!(item_a.unwrap().content, "first item content");

        let item_b = db.get_item("id-bbb").await.unwrap();
        assert!(item_b.is_some(), "Item id-bbb should be preserved");
        assert_eq!(item_b.unwrap().content, "second item content");

        let stats = db.stats().await.unwrap();
        assert_eq!(stats.item_count, 2, "Should have 2 items after migration");
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_a_only_staging() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["staging-id"])),
                Arc::new(StringArray::from(vec!["staging content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items_migrated", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("staging-id").await.unwrap();
        assert!(item.is_some(), "Item should be recovered from staging");
        assert_eq!(item.unwrap().content, "staging content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_b_both_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let old_schema = Arc::new(old_item_schema());
        let batch = old_item_batch("old-id", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], old_schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema());
        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "Should have migrated after recovery");

        let item = db.get_item("old-id").await.unwrap();
        assert!(
            item.is_some(),
            "Item should be preserved through recovery + migration"
        );

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_c_both_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema());

        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            new_schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["new-id"])),
                Arc::new(StringArray::from(vec!["new content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], new_schema.clone());
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("new-id").await.unwrap();
        assert!(item.is_some(), "Item should be untouched");
        assert_eq!(item.unwrap().content, "new content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_list_items_rejects_invalid_project_id() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let malicious_pid = "'; DROP TABLE items;--".to_string();

        let mut db = Database::open_with_project(&db_path, Some(malicious_pid))
            .await
            .unwrap();

        let result = db
            .list_items(ItemFilters::new(), Some(10), crate::ListScope::Project)
            .await;

        assert!(result.is_err(), "Should reject invalid project_id");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Invalid project_id"),
            "Error should mention invalid project_id, got: {}",
            err_msg
        );
    }
}