use std::path::PathBuf;
use std::sync::Arc;

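/// Escape and strip characters with SQL meaning before a value is
/// interpolated into a LanceDB filter string: NUL bytes, semicolons,
/// and comment tokens are removed; backslashes and single quotes are
/// escaped. Defense in depth alongside `is_valid_id`.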
fn sanitize_sql_string(s: &str) -> String {
    s.replace('\0', "")
        .replace('\\', "\\\\")
        .replace('\'', "''")
        .replace(';', "")
        .replace("--", "")
        .replace("/*", "")
        .replace("*/", "")
}

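/// Accept only non-empty strings of at most 64 ASCII hex digits or
/// hyphens (the shape of the UUID-style ids used here), rejecting
/// anything that could smuggle SQL metacharacters into a filter.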
fn is_valid_id(s: &str) -> bool {
    !s.is_empty() && s.len() <= 64 && s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
}

use arrow_array::{
    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
    RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use chrono::{TimeZone, Utc};
use futures::TryStreamExt;
use lancedb::Table;
use lancedb::connect;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::{ExecutableQuery, QueryBase};
use tracing::{debug, info};

use crate::boost_similarity;
use crate::chunker::{ChunkingConfig, chunk_content};
use crate::document::ContentType;
use crate::embedder::{EMBEDDING_DIM, Embedder};
use crate::error::{Result, SedimentError};
use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};

/// Items longer than this many characters get split into chunks.
const CHUNK_THRESHOLD: usize = 1000;

/// Minimum similarity for a stored item to be reported as a potential
/// conflict.
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

/// Maximum number of potential conflicts returned per store.
const CONFLICT_SEARCH_LIMIT: usize = 5;

/// Hard cap on the number of chunks stored for a single item.
const MAX_CHUNKS_PER_ITEM: usize = 200;

/// Number of chunk texts embedded per batch call.
const EMBEDDING_BATCH_SIZE: usize = 32;

/// Maximum additive boost a full-text-search hit can contribute to a
/// vector similarity score.
const FTS_BOOST_MAX: f32 = 0.12;

/// Exponent applied to the normalized BM25 score when computing the
/// FTS boost; higher values concentrate the boost on the top hits.
const FTS_GAMMA: f32 = 2.0;

/// Row count above which searches use the vector index instead of a
/// brute-force scan.
const VECTOR_INDEX_THRESHOLD: usize = 5000;

/// Handle to the LanceDB-backed store: one table for items, one for
/// their chunks, plus the embedder used for both writes and queries.
pub struct Database {
    db: lancedb::Connection,
    embedder: Arc<Embedder>,
    project_id: Option<String>,
    items_table: Option<Table>,
    chunks_table: Option<Table>,
    // FTS boost parameters; overridable via env vars under the `bench` feature.
    fts_boost_max: f32,
    fts_gamma: f32,
}

#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    pub item_count: usize,
    pub chunk_count: usize,
}

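/// Current schema version; version 2 dropped the legacy `tags` column.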
const SCHEMA_VERSION: i32 = 2;

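/// Arrow schema for the `items` table (schema version 2).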
fn item_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("project_id", DataType::Utf8, true),
        Field::new("is_chunked", DataType::Boolean, false),
        Field::new("created_at", DataType::Int64, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

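/// Arrow schema for the `chunks` table.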
fn chunk_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("item_id", DataType::Utf8, false),
        Field::new("chunk_index", DataType::Int32, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("context", DataType::Utf8, true),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

impl Database {
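    /// Open (or create) a database at `path` with a freshly
    /// constructed embedder and no project scope.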
    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
        Self::open_with_project(path, None).await
    }

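    /// Open (or create) a database scoped to an optional project id.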
    pub async fn open_with_project(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
    ) -> Result<Self> {
        let embedder = Arc::new(Embedder::new()?);
        Self::open_with_embedder(path, project_id, embedder).await
    }

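    /// Open (or create) a database using a caller-supplied embedder.
    ///
    /// Creates the parent directory if needed, connects to LanceDB,
    /// recovers any interrupted migration, migrates old schemas, and
    /// ensures vector/FTS indexes exist before returning.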
    pub async fn open_with_embedder(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
        embedder: Arc<Embedder>,
    ) -> Result<Self> {
        let path = path.into();
        info!("Opening database at {:?}", path);

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                SedimentError::Database(format!("Failed to create database directory: {}", e))
            })?;
        }

        let db = connect(path.to_str().ok_or_else(|| {
            SedimentError::Database("Database path contains invalid UTF-8".to_string())
        })?)
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;

        // Benchmark builds may override the FTS boost parameters via
        // environment variables.
        #[cfg(feature = "bench")]
        let (fts_boost_max, fts_gamma) = {
            let boost = std::env::var("SEDIMENT_FTS_BOOST_MAX")
                .ok()
                .and_then(|v| v.parse::<f32>().ok())
                .unwrap_or(FTS_BOOST_MAX);
            let gamma = std::env::var("SEDIMENT_FTS_GAMMA")
                .ok()
                .and_then(|v| v.parse::<f32>().ok())
                .unwrap_or(FTS_GAMMA);
            (boost, gamma)
        };
        #[cfg(not(feature = "bench"))]
        let (fts_boost_max, fts_gamma) = (FTS_BOOST_MAX, FTS_GAMMA);

        let mut database = Self {
            db,
            embedder,
            project_id,
            items_table: None,
            chunks_table: None,
            fts_boost_max,
            fts_gamma,
        };

        database.ensure_tables().await?;
        database.ensure_vector_index().await?;

        Ok(database)
    }

    pub fn set_project_id(&mut self, project_id: Option<String>) {
        self.project_id = project_id;
    }

    pub fn project_id(&self) -> Option<&str> {
        self.project_id.as_deref()
    }

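    /// Open existing tables, recovering from any interrupted migration
    /// and migrating v1 (`tags`-bearing) schemas before use.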
    async fn ensure_tables(&mut self) -> Result<()> {
        let mut table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

        if table_names.contains(&"items_migrated".to_string()) {
            info!("Detected interrupted migration, recovering...");
            self.recover_interrupted_migration(&table_names).await?;
            table_names = self.db.table_names().execute().await.map_err(|e| {
                SedimentError::Database(format!("Failed to list tables: {}", e))
            })?;
        }

        if table_names.contains(&"items".to_string()) {
            let needs_migration = self.check_needs_migration().await?;
            if needs_migration {
                info!("Migrating database schema to version {}", SCHEMA_VERSION);
                self.migrate_schema().await?;
            }
        }

        if table_names.contains(&"items".to_string()) {
            self.items_table = Some(self.db.open_table("items").execute().await.map_err(|e| {
                SedimentError::Database(format!("Failed to open items: {}", e))
            })?);
        }

        if table_names.contains(&"chunks".to_string()) {
            self.chunks_table = Some(self.db.open_table("chunks").execute().await.map_err(|e| {
                SedimentError::Database(format!("Failed to open chunks: {}", e))
            })?);
        }

        Ok(())
    }

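    /// A v1 schema is identified by the presence of the legacy `tags`
    /// column on the `items` table.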
    async fn check_needs_migration(&self) -> Result<bool> {
        let table = self.db.open_table("items").execute().await.map_err(|e| {
            SedimentError::Database(format!("Failed to open items for check: {}", e))
        })?;

        let schema = table
            .schema()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to get schema: {}", e)))?;

        let has_tags = schema.fields().iter().any(|f| f.name() == "tags");
        Ok(has_tags)
    }

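    /// Recover from a migration interrupted mid-flight, keyed off a
    /// leftover `items_migrated` staging table:
    ///
    /// - Case A: `items` is gone, so the staging table holds the only
    ///   copy of the data; restore it into a fresh `items` table.
    /// - Case B: `items` still has the old schema, so the staging
    ///   table is incomplete; drop it and let migration rerun.
    /// - Case C: `items` already has the new schema; the staging
    ///   table is leftover debris and is simply dropped.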
    async fn recover_interrupted_migration(&mut self, table_names: &[String]) -> Result<()> {
        let has_items = table_names.contains(&"items".to_string());

        if !has_items {
            info!("Recovery case A: restoring items from items_migrated");
            let staging = self
                .db
                .open_table("items_migrated")
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to open staging table: {}", e))
                })?;

            let results = staging
                .query()
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery query failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| SedimentError::Database(format!("Recovery collect failed: {}", e)))?;

            let schema = Arc::new(item_schema());
            let new_table = self
                .db
                .create_empty_table("items", schema.clone())
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;

            if !results.is_empty() {
                let batches = RecordBatchIterator::new(results.into_iter().map(Ok), schema);
                new_table
                    .add(Box::new(batches))
                    .execute()
                    .await
                    .map_err(|e| {
                        SedimentError::Database(format!("Failed to restore items: {}", e))
                    })?;
            }

            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop staging table: {}", e))
            })?;
            info!("Recovery case A completed");
        } else {
            let has_old_schema = self.check_needs_migration().await?;

            if has_old_schema {
                info!("Recovery case B: dropping incomplete staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            } else {
                info!("Recovery case C: dropping leftover staging table");
                self.db.drop_table("items_migrated").await.map_err(|e| {
                    SedimentError::Database(format!("Failed to drop staging table: {}", e))
                })?;
            }
        }

        Ok(())
    }

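    /// Migrate `items` from schema v1 to v2 (dropping the `tags`
    /// column) through a staging table so an interruption at any
    /// point is recoverable:
    ///
    /// 1. Copy all rows, converted to the new schema, into
    ///    `items_migrated` and verify the row counts match.
    /// 2. Drop the old `items` table and rebuild it from staging.
    /// 3. Drop the staging table.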
    async fn migrate_schema(&mut self) -> Result<()> {
        info!("Starting schema migration...");

        let old_table = self
            .db
            .open_table("items")
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to open old items: {}", e)))?;

        let results = old_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration collect failed: {}", e)))?;

        // Convert every batch to the new schema before touching any table.
        let mut new_batches = Vec::new();
        for batch in &results {
            let converted = self.convert_batch_to_new_schema(batch)?;
            new_batches.push(converted);
        }

        // Sanity check: conversion must not lose rows.
        let old_count: usize = results.iter().map(|b| b.num_rows()).sum();
        let new_count: usize = new_batches.iter().map(|b| b.num_rows()).sum();
        if old_count != new_count {
            return Err(SedimentError::Database(format!(
                "Migration row count mismatch: old={}, new={}",
                old_count, new_count
            )));
        }
        info!("Migrating {} items to new schema", old_count);

        // Drop any stale staging table from a previous attempt.
        let table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
        if table_names.contains(&"items_migrated".to_string()) {
            self.db.drop_table("items_migrated").await.map_err(|e| {
                SedimentError::Database(format!("Failed to drop stale staging: {}", e))
            })?;
        }

        let schema = Arc::new(item_schema());
        let staging_table = self
            .db
            .create_empty_table("items_migrated", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create staging table: {}", e))
            })?;

        if !new_batches.is_empty() {
            let batches = RecordBatchIterator::new(new_batches.into_iter().map(Ok), schema.clone());
            staging_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert into staging: {}", e))
                })?;
        }

        // Verify the staging table before dropping the original data.
        let staging_count = staging_table
            .count_rows(None)
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to count staging rows: {}", e)))?;
        if staging_count != old_count {
            let _ = self.db.drop_table("items_migrated").await;
            return Err(SedimentError::Database(format!(
                "Staging row count mismatch: expected {}, got {}",
                old_count, staging_count
            )));
        }

        self.db.drop_table("items").await.map_err(|e| {
            SedimentError::Database(format!("Failed to drop old items table: {}", e))
        })?;

        let staging_data = staging_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to read staging: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect staging: {}", e)))?;

        let new_table = self
            .db
            .create_empty_table("items", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create new items table: {}", e))
            })?;

        if !staging_data.is_empty() {
            let batches = RecordBatchIterator::new(staging_data.into_iter().map(Ok), schema);
            new_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert migrated items: {}", e))
                })?;
        }

        self.db
            .drop_table("items_migrated")
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to drop staging table: {}", e)))?;

        info!("Schema migration completed successfully");
        Ok(())
    }

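    /// Project a v1 record batch onto the v2 schema by copying every
    /// surviving column and omitting `tags`.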
    fn convert_batch_to_new_schema(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        let schema = Arc::new(item_schema());

        let id_col = batch
            .column_by_name("id")
            .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?
            .clone();

        let content_col = batch
            .column_by_name("content")
            .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?
            .clone();

        let project_id_col = batch
            .column_by_name("project_id")
            .ok_or_else(|| SedimentError::Database("Missing project_id column".to_string()))?
            .clone();

        let is_chunked_col = batch
            .column_by_name("is_chunked")
            .ok_or_else(|| SedimentError::Database("Missing is_chunked column".to_string()))?
            .clone();

        let created_at_col = batch
            .column_by_name("created_at")
            .ok_or_else(|| SedimentError::Database("Missing created_at column".to_string()))?
            .clone();

        let vector_col = batch
            .column_by_name("vector")
            .ok_or_else(|| SedimentError::Database("Missing vector column".to_string()))?
            .clone();

        RecordBatch::try_new(
            schema,
            vec![
                id_col,
                content_col,
                project_id_col,
                is_chunked_col,
                created_at_col,
                vector_col,
            ],
        )
        .map_err(|e| SedimentError::Database(format!("Failed to create migrated batch: {}", e)))
    }

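    /// Create vector and FTS indexes once the tables are large enough.
    /// Index creation failures are logged and tolerated: searches fall
    /// back to brute-force scans.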
    async fn ensure_vector_index(&self) -> Result<()> {
        const MIN_ROWS_FOR_INDEX: usize = 256;

        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
            if let Some(table) = table_opt {
                let row_count = table.count_rows(None).await.unwrap_or(0);

                let indices = table.list_indices().await.unwrap_or_default();

                if row_count >= MIN_ROWS_FOR_INDEX {
                    let has_vector_index = indices
                        .iter()
                        .any(|idx| idx.columns.contains(&"vector".to_string()));

                    if !has_vector_index {
                        info!(
                            "Creating vector index on {} table ({} rows)",
                            name, row_count
                        );
                        match table
                            .create_index(&["vector"], lancedb::index::Index::Auto)
                            .execute()
                            .await
                        {
                            Ok(_) => info!("Vector index created on {} table", name),
                            Err(e) => {
                                tracing::warn!("Failed to create vector index on {}: {}", name, e);
                            }
                        }
                    }
                }

                if row_count > 0 {
                    let has_fts_index = indices
                        .iter()
                        .any(|idx| idx.columns.contains(&"content".to_string()));

                    if !has_fts_index {
                        info!("Creating FTS index on {} table ({} rows)", name, row_count);
                        match table
                            .create_index(
                                &["content"],
                                lancedb::index::Index::FTS(Default::default()),
                            )
                            .execute()
                            .await
                        {
                            Ok(_) => info!("FTS index created on {} table", name),
                            Err(e) => {
                                tracing::warn!("Failed to create FTS index on {}: {}", name, e);
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }

    async fn get_items_table(&mut self) -> Result<&Table> {
        if self.items_table.is_none() {
            let schema = Arc::new(item_schema());
            let table = self
                .db
                .create_empty_table("items", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;
            self.items_table = Some(table);
        }
        Ok(self.items_table.as_ref().unwrap())
    }

    async fn get_chunks_table(&mut self) -> Result<&Table> {
        if self.chunks_table.is_none() {
            let schema = Arc::new(chunk_schema());
            let table = self
                .db
                .create_empty_table("chunks", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
                })?;
            self.chunks_table = Some(table);
        }
        Ok(self.chunks_table.as_ref().unwrap())
    }

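    /// Store an item, chunking its content when it exceeds
    /// `CHUNK_THRESHOLD` characters, and report any stored items whose
    /// embeddings are similar enough to be potential conflicts.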
    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
        // Default to the database's project scope.
        if item.project_id.is_none() {
            item.project_id = self.project_id.clone();
        }

        // Chunking is decided on character count, not byte length.
        let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
        item.is_chunked = should_chunk;

        let embedding_text = item.embedding_text();
        let embedding = self.embedder.embed(&embedding_text)?;
        item.embedding = embedding;

        let table = self.get_items_table().await?;
        let batch = item_to_batch(&item)?;
        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));

        table
            .add(Box::new(batches))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;

        if should_chunk {
            let content_type = detect_content_type(&item.content);
            let config = ChunkingConfig::default();
            let mut chunk_results = chunk_content(&item.content, content_type, &config);

            if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
                tracing::warn!(
                    "Chunk count {} exceeds limit {}, truncating",
                    chunk_results.len(),
                    MAX_CHUNKS_PER_ITEM
                );
                chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
            }

            // Roll back the item if its chunks cannot be stored.
            if let Err(e) = self.store_chunks(&item.id, &chunk_results).await {
                let _ = self.delete_item(&item.id).await;
                return Err(e);
            }

            debug!(
                "Stored item: {} with {} chunks",
                item.id,
                chunk_results.len()
            );
        } else {
            debug!("Stored item: {} (no chunking)", item.id);
        }

        // Conflict detection is best-effort; failures return no conflicts.
        let potential_conflicts = self
            .find_similar_items_by_vector(
                &item.embedding,
                Some(&item.id),
                CONFLICT_SIMILARITY_THRESHOLD,
                CONFLICT_SEARCH_LIMIT,
            )
            .await
            .unwrap_or_default();

        Ok(StoreResult {
            id: item.id,
            potential_conflicts,
        })
    }

    async fn store_chunks(
        &mut self,
        item_id: &str,
        chunk_results: &[crate::chunker::ChunkResult],
    ) -> Result<()> {
        // Clone the Arc so the embedder stays usable while
        // `chunks_table` borrows `self`.
        let embedder = self.embedder.clone();
        let chunks_table = self.get_chunks_table().await?;

        // Embed the chunk texts in fixed-size batches.
        let chunk_texts: Vec<&str> = chunk_results.iter().map(|cr| cr.content.as_str()).collect();
        let mut all_embeddings = Vec::with_capacity(chunk_texts.len());
        for batch_start in (0..chunk_texts.len()).step_by(EMBEDDING_BATCH_SIZE) {
            let batch_end = (batch_start + EMBEDDING_BATCH_SIZE).min(chunk_texts.len());
            let batch_embeddings = embedder.embed_batch(&chunk_texts[batch_start..batch_end])?;
            all_embeddings.extend(batch_embeddings);
        }

        let mut all_chunk_batches = Vec::with_capacity(chunk_results.len());
        for (i, (chunk_result, embedding)) in chunk_results.iter().zip(all_embeddings).enumerate() {
            let mut chunk = Chunk::new(item_id, i, &chunk_result.content);
            if let Some(ctx) = &chunk_result.context {
                chunk = chunk.with_context(ctx);
            }
            chunk.embedding = embedding;
            all_chunk_batches.push(chunk_to_batch(&chunk)?);
        }

        if !all_chunk_batches.is_empty() {
            let schema = Arc::new(chunk_schema());
            let batches = RecordBatchIterator::new(all_chunk_batches.into_iter().map(Ok), schema);
            chunks_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Failed to store chunks: {}", e)))?;
        }

        Ok(())
    }

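    /// Run a best-effort BM25 full-text search over `content` and
    /// return raw scores keyed by row id; `None` if FTS is unavailable.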
    async fn fts_rank_items(
        &self,
        table: &Table,
        query: &str,
        limit: usize,
    ) -> Option<std::collections::HashMap<String, f32>> {
        let fts_query =
            FullTextSearchQuery::new(query.to_string()).columns(Some(vec!["content".to_string()]));

        let fts_results = table
            .query()
            .full_text_search(fts_query)
            .limit(limit)
            .execute()
            .await
            .ok()?
            .try_collect::<Vec<_>>()
            .await
            .ok()?;

        let mut scores = std::collections::HashMap::new();
        for batch in fts_results {
            let ids = batch
                .column_by_name("id")
                .and_then(|c| c.as_any().downcast_ref::<StringArray>())?;
            let bm25_scores = batch
                .column_by_name("_score")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
            for i in 0..ids.len() {
                if !ids.is_null(i) {
                    let score = bm25_scores.map(|s| s.value(i)).unwrap_or(0.0);
                    scores.insert(ids.value(i).to_string(), score);
                }
            }
        }
        Some(scores)
    }

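    /// Hybrid search: vector similarity over items and chunks, with a
    /// bounded FTS boost on item scores, deduplicated per item and
    /// sorted by boosted similarity.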
    pub async fn search_items(
        &mut self,
        query: &str,
        limit: usize,
        filters: ItemFilters,
    ) -> Result<Vec<SearchResult>> {
        let limit = limit.min(1000);
        self.ensure_vector_index().await?;

        let query_embedding = self.embedder.embed(query)?;
        let min_similarity = filters.min_similarity.unwrap_or(0.3);

        // Best result per item id, keyed for deduplication.
        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
            std::collections::HashMap::new();

        if let Some(table) = &self.items_table {
            let row_count = table.count_rows(None).await.unwrap_or(0);
            let base_query = table
                .vector_search(query_embedding.clone())
                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
            // Small tables are scanned exactly; large ones go through
            // the index with a refine pass.
            let query_builder = if row_count < VECTOR_INDEX_THRESHOLD {
                base_query.bypass_vector_index().limit(limit * 2)
            } else {
                base_query.refine_factor(10).limit(limit * 2)
            };

            let results = query_builder
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect results: {}", e))
                })?;

            let mut vector_items: Vec<(Item, f32)> = Vec::new();
            for batch in results {
                let items = batch_to_items(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, item) in items.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);
                    if similarity >= min_similarity {
                        vector_items.push((item, similarity));
                    }
                }
            }

            let fts_ranking = self.fts_rank_items(table, query, limit * 2).await;

            // Normalize BM25 scores by the best score in this result set.
            let max_bm25 = fts_ranking
                .as_ref()
                .and_then(|scores| scores.values().cloned().reduce(f32::max))
                .unwrap_or(1.0)
                .max(f32::EPSILON);

            for (item, similarity) in vector_items {
                // Boost items that also matched the FTS query, scaled by
                // their normalized BM25 score raised to `fts_gamma`.
                let fts_boost = fts_ranking.as_ref().map_or(0.0, |scores| {
                    scores.get(&item.id).map_or(0.0, |&bm25_score| {
                        self.fts_boost_max * (bm25_score / max_bm25).powf(self.fts_gamma)
                    })
                });
                let boosted_similarity = boost_similarity(
                    similarity + fts_boost,
                    item.project_id.as_deref(),
                    self.project_id.as_deref(),
                );

                let result = SearchResult::from_item(&item, boosted_similarity);
                results_map
                    .entry(item.id.clone())
                    .or_insert((result, boosted_similarity));
            }
        }

        if let Some(chunks_table) = &self.chunks_table {
            let chunk_row_count = chunks_table.count_rows(None).await.unwrap_or(0);
            let chunk_base_query = chunks_table.vector_search(query_embedding).map_err(|e| {
                SedimentError::Database(format!("Failed to build chunk search: {}", e))
            })?;
            let chunk_results = if chunk_row_count < VECTOR_INDEX_THRESHOLD {
                chunk_base_query.bypass_vector_index().limit(limit * 3)
            } else {
                chunk_base_query.refine_factor(10).limit(limit * 3)
            }
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to collect chunk results: {}", e))
            })?;

            // Best-matching chunk (content and similarity) per parent item.
            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
                std::collections::HashMap::new();

            for batch in chunk_results {
                let chunks = batch_to_chunks(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    chunk_matches
                        .entry(chunk.item_id.clone())
                        .and_modify(|(content, best_sim)| {
                            if similarity > *best_sim {
                                *content = chunk.content.clone();
                                *best_sim = similarity;
                            }
                        })
                        .or_insert((chunk.content.clone(), similarity));
                }
            }

            // Fetch the parent items for all matching chunks in one query.
            let chunk_item_ids: Vec<&str> = chunk_matches.keys().map(|id| id.as_str()).collect();
            let parent_items = self.get_items_batch(&chunk_item_ids).await?;
            let parent_map: std::collections::HashMap<&str, &Item> = parent_items
                .iter()
                .map(|item| (item.id.as_str(), item))
                .collect();

            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
                if let Some(item) = parent_map.get(item_id.as_str()) {
                    let boosted_similarity = boost_similarity(
                        chunk_similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result =
                        SearchResult::from_item_with_excerpt(item, boosted_similarity, excerpt);

                    // Keep whichever of the item-level or chunk-level
                    // match scored higher.
                    results_map
                        .entry(item_id)
                        .and_modify(|(existing, existing_sim)| {
                            if boosted_similarity > *existing_sim {
                                *existing = result.clone();
                                *existing_sim = boosted_similarity;
                            }
                        })
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        let mut search_results: Vec<SearchResult> =
            results_map.into_values().map(|(sr, _)| sr).collect();
        search_results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        search_results.truncate(limit);

        Ok(search_results)
    }

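    /// Embed `content` and return stored items whose similarity meets
    /// `min_similarity`, as conflict candidates.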
    pub async fn find_similar_items(
        &mut self,
        content: &str,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let embedding = self.embedder.embed(content)?;
        self.find_similar_items_by_vector(&embedding, None, min_similarity, limit)
            .await
    }

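    /// Vector-only conflict search, optionally excluding one id (the
    /// item that was just stored).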
    pub async fn find_similar_items_by_vector(
        &self,
        embedding: &[f32],
        exclude_id: Option<&str>,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let row_count = table.count_rows(None).await.unwrap_or(0);
        let base_query = table
            .vector_search(embedding.to_vec())
            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?;
        let results = if row_count < VECTOR_INDEX_THRESHOLD {
            base_query.bypass_vector_index().limit(limit)
        } else {
            base_query.refine_factor(10).limit(limit)
        }
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
        .try_collect::<Vec<_>>()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;

        let mut conflicts = Vec::new();

        for batch in results {
            let items = batch_to_items(&batch)?;
            let distances = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

            for (i, item) in items.into_iter().enumerate() {
                if exclude_id.is_some_and(|eid| eid == item.id) {
                    continue;
                }

                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                let similarity = 1.0 / (1.0 + distance);

                if similarity >= min_similarity {
                    conflicts.push(ConflictInfo {
                        id: item.id,
                        content: item.content,
                        similarity,
                    });
                }
            }
        }

        conflicts.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(conflicts)
    }

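    /// List items within the given scope: the current project, global
    /// (no project), or everything.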
    pub async fn list_items(
        &mut self,
        _filters: ItemFilters,
        limit: Option<usize>,
        scope: crate::ListScope,
    ) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let mut filter_parts = Vec::new();

        match scope {
            crate::ListScope::Project => {
                if let Some(ref pid) = self.project_id {
                    if !is_valid_id(pid) {
                        return Err(SedimentError::Database(
                            "Invalid project_id for list filter".to_string(),
                        ));
                    }
                    filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
                } else {
                    // No project scope configured, so nothing can match.
                    return Ok(Vec::new());
                }
            }
            crate::ListScope::Global => {
                filter_parts.push("project_id IS NULL".to_string());
            }
            crate::ListScope::All => {
                // No filter: include every item.
            }
        }

        let mut query = table.query();

        if !filter_parts.is_empty() {
            let filter_str = filter_parts.join(" AND ");
            query = query.only_if(filter_str);
        }

        if let Some(l) = limit {
            query = query.limit(l);
        }

        let results = query
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

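    /// Fetch a single item by id; invalid ids return `None` rather
    /// than an error.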
    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
        if !is_valid_id(id) {
            return Ok(None);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(None),
        };

        let results = table
            .query()
            .only_if(format!("id = '{}'", sanitize_sql_string(id)))
            .limit(1)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        for batch in results {
            let items = batch_to_items(&batch)?;
            if let Some(item) = items.into_iter().next() {
                return Ok(Some(item));
            }
        }

        Ok(None)
    }

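    /// Fetch several items by id in one `IN (...)` query; invalid ids
    /// are silently skipped.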
    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let quoted: Vec<String> = ids
            .iter()
            .filter(|id| is_valid_id(id))
            .map(|id| format!("'{}'", sanitize_sql_string(id)))
            .collect();
        if quoted.is_empty() {
            return Ok(Vec::new());
        }
        let filter = format!("id IN ({})", quoted.join(", "));

        let results = table
            .query()
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

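    /// Delete an item and its chunks; returns `false` if the id is
    /// invalid or the item does not exist.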
    pub async fn delete_item(&self, id: &str) -> Result<bool> {
        if !is_valid_id(id) {
            return Ok(false);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(false),
        };

        let exists = self.get_item(id).await?.is_some();
        if !exists {
            return Ok(false);
        }

        // Delete chunks first so a failure cannot leave orphaned
        // chunks pointing at a missing item.
        if let Some(chunks_table) = &self.chunks_table {
            chunks_table
                .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
                .await
                .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
        }

        table
            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;

        Ok(true)
    }

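    /// Row counts for the items and chunks tables.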
    pub async fn stats(&self) -> Result<DatabaseStats> {
        let mut stats = DatabaseStats::default();

        if let Some(table) = &self.items_table {
            stats.item_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        if let Some(table) = &self.chunks_table {
            stats.chunk_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        Ok(stats)
    }
}

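/// Rewrite `project_id` values from `old_id` to `new_id` across the
/// items table, returning the number of rows updated. Both ids must
/// pass `is_valid_id`.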
pub async fn migrate_project_id(
    db_path: &std::path::Path,
    old_id: &str,
    new_id: &str,
) -> Result<u64> {
    if !is_valid_id(old_id) || !is_valid_id(new_id) {
        return Err(SedimentError::Database(
            "Invalid project ID for migration".to_string(),
        ));
    }

    let db = connect(db_path.to_str().ok_or_else(|| {
        SedimentError::Database("Database path contains invalid UTF-8".to_string())
    })?)
    .execute()
    .await
    .map_err(|e| SedimentError::Database(format!("Failed to connect for migration: {}", e)))?;

    let table_names = db
        .table_names()
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

    let mut total_updated = 0u64;

    if table_names.contains(&"items".to_string()) {
        let table = db.open_table("items").execute().await.map_err(|e| {
            SedimentError::Database(format!("Failed to open items table: {}", e))
        })?;

        let updated = table
            .update()
            .only_if(format!("project_id = '{}'", sanitize_sql_string(old_id)))
            .column("project_id", format!("'{}'", sanitize_sql_string(new_id)))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to migrate items: {}", e)))?;

        total_updated += updated;
        info!(
            "Migrated {} items from project {} to {}",
            updated, old_id, new_id
        );
    }

    Ok(total_updated)
}

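/// Combine a similarity score with recency and usage signals:
///
/// - freshness decays as `1 / (1 + age_days / 30)`, where age is
///   measured from `last_accessed_at` (or `created_at` if never
///   accessed), so a 30-day-old score is halved;
/// - frequency scales by `1 + 0.1 * ln(1 + access_count)`.
///
/// Non-finite inputs or results collapse to 0.0.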
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    if !similarity.is_finite() {
        return 0.0;
    }

    let reference_time = last_accessed_at.unwrap_or(created_at);
    let age_secs = (now - reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    let result = similarity * (freshness * frequency) as f32;
    if result.is_finite() { result } else { 0.0 }
}

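/// Heuristically classify content as JSON, YAML, Markdown, code, or
/// plain text to pick a chunking strategy.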
fn detect_content_type(content: &str) -> ContentType {
    let trimmed = content.trim();

    // JSON: balanced braces/brackets that actually parse.
    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
    {
        return ContentType::Json;
    }

    // YAML: require multiple identifier-like `key: value` lines (or a
    // `---` document marker) so prose with colons is not misclassified.
    if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
        let lines: Vec<&str> = trimmed.lines().take(10).collect();
        let yaml_key_count = lines
            .iter()
            .filter(|line| {
                let l = line.trim();
                !l.is_empty()
                    && !l.starts_with('#')
                    && !l.contains("://")
                    && l.contains(": ")
                    && l.split(": ").next().is_some_and(|key| {
                        let k = key.trim_start_matches("- ");
                        !k.is_empty()
                            && k.chars()
                                .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
                    })
            })
            .count();
        if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
            return ContentType::Yaml;
        }
    }

    // Markdown: ATX headings at the start of a line.
    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
        return ContentType::Markdown;
    }

    // Code: any line starting with a common declaration keyword.
    let code_patterns = [
        "fn ",
        "pub fn ",
        "def ",
        "class ",
        "function ",
        "const ",
        "let ",
        "var ",
        "import ",
        "export ",
        "struct ",
        "impl ",
        "trait ",
    ];
    let has_code_pattern = trimmed.lines().any(|line| {
        let l = line.trim();
        code_patterns.iter().any(|p| l.starts_with(p))
    });
    if has_code_pattern {
        return ContentType::Code;
    }

    ContentType::Text
}

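/// Build a single-row Arrow record batch from an item.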
fn item_to_batch(item: &Item) -> Result<RecordBatch> {
    let schema = Arc::new(item_schema());

    let id = StringArray::from(vec![item.id.as_str()]);
    let content = StringArray::from(vec![item.content.as_str()]);
    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);

    let vector = create_embedding_array(&item.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(content),
            Arc::new(project_id),
            Arc::new(is_chunked),
            Arc::new(created_at),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

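/// Decode items from a record batch, tolerating missing optional
/// columns.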
fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
    let mut items = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let project_id_col = batch
        .column_by_name("project_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let is_chunked_col = batch
        .column_by_name("is_chunked")
        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());

    let created_at_col = batch
        .column_by_name("created_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    let vector_col = batch
        .column_by_name("vector")
        .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let content = content_col.value(i).to_string();

        let project_id = project_id_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);

        let created_at = created_at_col
            .map(|c| {
                Utc.timestamp_opt(c.value(i), 0)
                    .single()
                    .unwrap_or_else(Utc::now)
            })
            .unwrap_or_else(Utc::now);

        let embedding = vector_col
            .and_then(|col| {
                let value = col.value(i);
                value
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .map(|arr| arr.values().to_vec())
            })
            .unwrap_or_default();

        let item = Item {
            id,
            content,
            embedding,
            project_id,
            is_chunked,
            created_at,
        };

        items.push(item);
    }

    Ok(items)
}

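/// Build a single-row Arrow record batch from a chunk.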
fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
    let schema = Arc::new(chunk_schema());

    let id = StringArray::from(vec![chunk.id.as_str()]);
    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
    let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
    let content = StringArray::from(vec![chunk.content.as_str()]);
    let context = StringArray::from(vec![chunk.context.as_deref()]);

    let vector = create_embedding_array(&chunk.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(item_id),
            Arc::new(chunk_index),
            Arc::new(content),
            Arc::new(context),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

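/// Decode chunks from a record batch.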
fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
    let mut chunks = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let item_id_col = batch
        .column_by_name("item_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;

    let chunk_index_col = batch
        .column_by_name("chunk_index")
        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let context_col = batch
        .column_by_name("context")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let item_id = item_id_col.value(i).to_string();
        let chunk_index = chunk_index_col.value(i) as usize;
        let content = content_col.value(i).to_string();
        let context = context_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let chunk = Chunk {
            id,
            item_id,
            chunk_index,
            content,
            // Callers only need text and similarity, so the stored
            // embedding is not copied out.
            embedding: Vec::new(),
            context,
        };

        chunks.push(chunk);
    }

    Ok(chunks)
}

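/// Wrap an embedding slice in the fixed-size Arrow list layout used by
/// the vector columns.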
fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
    let values = Float32Array::from(embedding.to_vec());
    let field = Arc::new(Field::new("item", DataType::Float32, true));

    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_with_decay_fresh_item() {
        let now = 1700000000i64;
        let created = now;
        let score = score_with_decay(0.8, now, created, 0, None);
        // Zero age and zero accesses: freshness and frequency are both 1.
        let expected = 0.8 * 1.0 * 1.0;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        // At 30 days, freshness is 1 / (1 + 1) = 0.5.
        let expected = 0.8 * 0.5;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let last_accessed = now;
        let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
        // Accessed just now, so freshness is 1; frequency reflects 10 accesses.
        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        let expected = 0.8 * 1.0 * freq as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        let now = 1700000000i64;
        let created = now - 90 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        // At 90 days, freshness is 1 / (1 + 3) = 0.25.
        let expected = 0.8 * 0.25;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
        assert_eq!(sanitize_sql_string("hello"), "hello");
        assert_eq!(sanitize_sql_string("it's"), "it''s");
        assert_eq!(sanitize_sql_string(r"a\'b"), r"a\\''b");
        assert_eq!(sanitize_sql_string(r"path\to\file"), r"path\\to\\file");
    }

    #[test]
    fn test_sanitize_sql_string_strips_null_bytes() {
        assert_eq!(sanitize_sql_string("abc\0def"), "abcdef");
        assert_eq!(sanitize_sql_string("\0' OR 1=1 --"), "'' OR 1=1 ");
        assert_eq!(sanitize_sql_string("*/ OR 1=1"), " OR 1=1");
        assert_eq!(sanitize_sql_string("clean"), "clean");
    }

    #[test]
    fn test_sanitize_sql_string_strips_semicolons() {
        assert_eq!(
            sanitize_sql_string("a; DROP TABLE items"),
            "a DROP TABLE items"
        );
        assert_eq!(sanitize_sql_string("normal;"), "normal");
    }

    #[test]
    fn test_sanitize_sql_string_strips_comments() {
        // Only the comment tokens are removed; the spaces around them remain.
        assert_eq!(sanitize_sql_string("val' -- comment"), "val''  comment");
        assert_eq!(sanitize_sql_string("val' /* block */"), "val''  block ");
        assert_eq!(sanitize_sql_string("a--b--c"), "abc");
        assert_eq!(sanitize_sql_string("injected */ rest"), "injected  rest");
        assert_eq!(sanitize_sql_string("*/"), "");
    }

    #[test]
    fn test_sanitize_sql_string_adversarial_inputs() {
        assert_eq!(
            sanitize_sql_string("'; DROP TABLE items;--"),
            "'' DROP TABLE items"
        );
        // Unicode without SQL meaning passes through untouched.
        assert_eq!(
            sanitize_sql_string("hello\u{200B}world"),
            "hello\u{200B}world"
        );
        assert_eq!(sanitize_sql_string(""), "");
        assert_eq!(sanitize_sql_string("\0;\0"), "");
    }

    #[test]
    fn test_is_valid_id() {
        assert!(is_valid_id("550e8400-e29b-41d4-a716-446655440000"));
        assert!(is_valid_id("abcdef0123456789"));
        assert!(!is_valid_id(""));
        assert!(!is_valid_id("'; DROP TABLE items;--"));
        assert!(!is_valid_id("hello world"));
        assert!(!is_valid_id("abc\0def"));
        assert!(!is_valid_id(&"a".repeat(65)));
    }

    #[test]
    fn test_detect_content_type_yaml_not_prose() {
        let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
        let detected = detect_content_type(prose);
        assert_ne!(
            detected,
            ContentType::Yaml,
            "Prose with colons should not be detected as YAML"
        );

        let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_detect_content_type_yaml_with_separator() {
        let yaml = "---\nname: test\nversion: 1.0";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_chunk_threshold_uses_chars_not_bytes() {
        let emoji_content = "😀".repeat(500);
        assert_eq!(emoji_content.chars().count(), 500);
        // Each emoji is 4 bytes in UTF-8, so the byte length is 4x.
        assert_eq!(emoji_content.len(), 2000);
        let should_chunk = emoji_content.chars().count() > CHUNK_THRESHOLD;
        assert!(
            !should_chunk,
            "500 chars should not exceed 1000-char threshold"
        );

        let long_content = "a".repeat(1001);
        let should_chunk = long_content.chars().count() > CHUNK_THRESHOLD;
        assert!(should_chunk, "1001 chars should exceed 1000-char threshold");
    }

    #[test]
    fn test_schema_version() {
        let version = SCHEMA_VERSION;
        assert!(version >= 2, "Schema version should be at least 2");
    }

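    /// The v1 schema, including the `tags` column dropped in v2.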
    fn old_item_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
            Field::new("project_id", DataType::Utf8, true),
            Field::new("tags", DataType::Utf8, true),
            Field::new("is_chunked", DataType::Boolean, false),
            Field::new("created_at", DataType::Int64, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    EMBEDDING_DIM as i32,
                ),
                false,
            ),
        ])
    }

    fn old_item_batch(id: &str, content: &str) -> RecordBatch {
        let schema = Arc::new(old_item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![id])),
                Arc::new(StringArray::from(vec![content])),
                Arc::new(StringArray::from(vec![None::<&str>])), // project_id
                Arc::new(StringArray::from(vec![None::<&str>])), // tags
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap()
    }

    #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_detects_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch = old_item_batch("test-id-1", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
            fts_boost_max: FTS_BOOST_MAX,
            fts_gamma: FTS_GAMMA,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            needs_migration,
            "Old schema with tags column should need migration"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_check_needs_migration_false_for_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema());
        db_conn
            .create_empty_table("items", schema)
            .execute()
            .await
            .unwrap();

        let db = Database {
            db: db_conn,
            embedder: Arc::new(Embedder::new().unwrap()),
            project_id: None,
            items_table: None,
            chunks_table: None,
            fts_boost_max: FTS_BOOST_MAX,
            fts_gamma: FTS_GAMMA,
        };

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "New schema should not need migration");
    }

    #[tokio::test]
    #[ignore]
    async fn test_migrate_schema_preserves_data() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(old_item_schema());
        let batch1 = old_item_batch("id-aaa", "first item content");
        let batch2 = old_item_batch("id-bbb", "second item content");
        let batches = RecordBatchIterator::new(vec![Ok(batch1), Ok(batch2)], schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(
            !needs_migration,
            "Schema should be migrated (no tags column)"
        );

        let item_a = db.get_item("id-aaa").await.unwrap();
        assert!(item_a.is_some(), "Item id-aaa should be preserved");
        assert_eq!(item_a.unwrap().content, "first item content");

        let item_b = db.get_item("id-bbb").await.unwrap();
        assert!(item_b.is_some(), "Item id-bbb should be preserved");
        assert_eq!(item_b.unwrap().content, "second item content");

        let stats = db.stats().await.unwrap();
        assert_eq!(stats.item_count, 2, "Should have 2 items after migration");
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_a_only_staging() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let schema = Arc::new(item_schema());
        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["staging-id"])),
                Arc::new(StringArray::from(vec!["staging content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], schema);
        db_conn
            .create_table("items_migrated", Box::new(batches))
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("staging-id").await.unwrap();
        assert!(item.is_some(), "Item should be recovered from staging");
        assert_eq!(item.unwrap().content, "staging content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_b_both_old_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let old_schema = Arc::new(old_item_schema());
        let batch = old_item_batch("old-id", "old content");
        let batches = RecordBatchIterator::new(vec![Ok(batch)], old_schema);
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema());
        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let needs_migration = db.check_needs_migration().await.unwrap();
        assert!(!needs_migration, "Should have migrated after recovery");

        let item = db.get_item("old-id").await.unwrap();
        assert!(
            item.is_some(),
            "Item should be preserved through recovery + migration"
        );

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_recover_case_c_both_new_schema() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");

        let db_conn = lancedb::connect(db_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let new_schema = Arc::new(item_schema());

        let vector_values = Float32Array::from(vec![0.0f32; EMBEDDING_DIM]);
        let vector_field = Arc::new(Field::new("item", DataType::Float32, true));
        let vector = FixedSizeListArray::try_new(
            vector_field,
            EMBEDDING_DIM as i32,
            Arc::new(vector_values),
            None,
        )
        .unwrap();

        let batch = RecordBatch::try_new(
            new_schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["new-id"])),
                Arc::new(StringArray::from(vec!["new content"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(BooleanArray::from(vec![false])),
                Arc::new(Int64Array::from(vec![1700000000i64])),
                Arc::new(vector),
            ],
        )
        .unwrap();

        let batches = RecordBatchIterator::new(vec![Ok(batch)], new_schema.clone());
        db_conn
            .create_table("items", Box::new(batches))
            .execute()
            .await
            .unwrap();

        db_conn
            .create_empty_table("items_migrated", new_schema)
            .execute()
            .await
            .unwrap();
        drop(db_conn);

        let embedder = Arc::new(Embedder::new().unwrap());
        let db = Database::open_with_embedder(&db_path, None, embedder)
            .await
            .unwrap();

        let item = db.get_item("new-id").await.unwrap();
        assert!(item.is_some(), "Item should be untouched");
        assert_eq!(item.unwrap().content, "new content");

        let table_names = db.db.table_names().execute().await.unwrap();
        assert!(
            !table_names.contains(&"items_migrated".to_string()),
            "Staging table should be dropped"
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_list_items_rejects_invalid_project_id() {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let malicious_pid = "'; DROP TABLE items;--".to_string();

        let mut db = Database::open_with_project(&db_path, Some(malicious_pid))
            .await
            .unwrap();

        let result = db
            .list_items(ItemFilters::new(), Some(10), crate::ListScope::Project)
            .await;

        assert!(result.is_err(), "Should reject invalid project_id");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Invalid project_id"),
            "Error should mention invalid project_id, got: {}",
            err_msg
        );
    }
}