use std::path::PathBuf;
use std::sync::Arc;

/// Strips or escapes characters that could break out of a single-quoted
/// SQL string literal: null bytes, backslashes, quotes, statement
/// separators, and comment openers.
fn sanitize_sql_string(s: &str) -> String {
    s.replace('\0', "")
        .replace('\\', "\\\\")
        .replace('\'', "''")
        .replace(';', "")
        .replace("--", "")
        .replace("/*", "")
}

/// Returns true only for non-empty strings of at most 64 characters made up
/// of ASCII hex digits and hyphens (UUIDs and hex hashes).
fn is_valid_id(s: &str) -> bool {
    !s.is_empty() && s.len() <= 64 && s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
}

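// Illustrative note (not in the original source): the two guards above are
// meant to compose when interpolating ids into LanceDB filter strings, e.g.
//
//     if is_valid_id(id) {
//         let filter = format!("id = '{}'", sanitize_sql_string(id));
//         // pass `filter` to `query().only_if(filter)`
//     }
//
// Invalid ids are rejected outright rather than sanitized.
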
use arrow_array::{
    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
    RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use chrono::{TimeZone, Utc};
use futures::TryStreamExt;
use lancedb::Table;
use lancedb::connect;
use lancedb::query::{ExecutableQuery, QueryBase};
use tracing::{debug, info};

use crate::boost_similarity;
use crate::chunker::{ChunkingConfig, chunk_content};
use crate::document::ContentType;
use crate::embedder::{EMBEDDING_DIM, Embedder};
use crate::error::{Result, SedimentError};
use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};

/// Items longer than this many characters are split into chunks.
const CHUNK_THRESHOLD: usize = 1000;

/// Minimum similarity for an existing item to be reported as a potential
/// conflict when storing new content.
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

/// Number of nearest neighbours fetched when checking for conflicts.
const CONFLICT_SEARCH_LIMIT: usize = 5;

/// Hard cap on chunks stored per item; anything beyond is truncated.
const MAX_CHUNKS_PER_ITEM: usize = 200;

/// Vector database handle backed by LanceDB, with lazily created
/// `items` and `chunks` tables.
pub struct Database {
    db: lancedb::Connection,
    embedder: Arc<Embedder>,
    project_id: Option<String>,
    items_table: Option<Table>,
    chunks_table: Option<Table>,
}

/// Row counts for the `items` and `chunks` tables.
#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    pub item_count: usize,
    pub chunk_count: usize,
}

/// Current on-disk schema version; version 2 dropped the legacy `tags` column.
const SCHEMA_VERSION: i32 = 2;

fn item_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("project_id", DataType::Utf8, true),
        Field::new("is_chunked", DataType::Boolean, false),
        Field::new("created_at", DataType::Int64, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

fn chunk_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("item_id", DataType::Utf8, false),
        Field::new("chunk_index", DataType::Int32, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("context", DataType::Utf8, true),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

impl Database {
    /// Opens (or creates) a database at `path` with a default embedder and
    /// no project scope.
    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
        Self::open_with_project(path, None).await
    }

    /// Opens a database scoped to an optional project id, constructing a
    /// fresh embedder.
    pub async fn open_with_project(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
    ) -> Result<Self> {
        let embedder = Arc::new(Embedder::new()?);
        Self::open_with_embedder(path, project_id, embedder).await
    }

    /// Opens a database with a caller-supplied embedder, creating the parent
    /// directory if needed, then ensures tables and vector indices exist.
    pub async fn open_with_embedder(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
        embedder: Arc<Embedder>,
    ) -> Result<Self> {
        let path = path.into();
        info!("Opening database at {:?}", path);

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                SedimentError::Database(format!("Failed to create database directory: {}", e))
            })?;
        }

        let db = connect(path.to_str().ok_or_else(|| {
            SedimentError::Database("Database path contains invalid UTF-8".to_string())
        })?)
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;

        let mut database = Self {
            db,
            embedder,
            project_id,
            items_table: None,
            chunks_table: None,
        };

        database.ensure_tables().await?;
        database.ensure_vector_index().await?;

        Ok(database)
    }

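    // Minimal usage sketch (illustrative; `Item::new` is an assumed
    // constructor and may not match the real API in `crate::item`):
    //
    //     let embedder = Arc::new(Embedder::new()?);
    //     let mut db = Database::open_with_embedder(
    //         "/tmp/sediment.lance",
    //         Some("my-project".to_string()),
    //         embedder,
    //     )
    //     .await?;
    //     let result = db.store_item(Item::new("note to remember")).await?;
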
    /// Replaces the active project scope.
    pub fn set_project_id(&mut self, project_id: Option<String>) {
        self.project_id = project_id;
    }

    /// Returns the active project scope, if any.
    pub fn project_id(&self) -> Option<&str> {
        self.project_id.as_deref()
    }

    /// Runs any pending schema migration, then opens the `items` and
    /// `chunks` tables if they already exist.
    async fn ensure_tables(&mut self) -> Result<()> {
        let table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

        if table_names.contains(&"items".to_string()) {
            let needs_migration = self.check_needs_migration().await?;
            if needs_migration {
                info!("Migrating database schema to version {}", SCHEMA_VERSION);
                self.migrate_schema().await?;
            }
        }

        if table_names.contains(&"items".to_string()) {
            self.items_table =
                Some(self.db.open_table("items").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open items: {}", e))
                })?);
        }

        if table_names.contains(&"chunks".to_string()) {
            self.chunks_table =
                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open chunks: {}", e))
                })?);
        }

        Ok(())
    }

    /// A database needs migration when the legacy `tags` column is still
    /// present on the `items` table.
    async fn check_needs_migration(&self) -> Result<bool> {
        let table = self.db.open_table("items").execute().await.map_err(|e| {
            SedimentError::Database(format!("Failed to open items for check: {}", e))
        })?;

        let schema = table
            .schema()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to get schema: {}", e)))?;

        let has_tags = schema.fields().iter().any(|f| f.name() == "tags");
        Ok(has_tags)
    }

    /// Rebuilds the `items` table under the current schema: reads every
    /// existing row, drops the old table, recreates it, and re-inserts the
    /// converted batches.
    async fn migrate_schema(&mut self) -> Result<()> {
        info!("Starting schema migration...");

        let old_table = self
            .db
            .open_table("items")
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to open old items: {}", e)))?;

        let results = old_table
            .query()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Migration collect failed: {}", e)))?;

        let mut new_batches = Vec::new();
        for batch in &results {
            let converted = self.convert_batch_to_new_schema(batch)?;
            new_batches.push(converted);
        }

        let item_count: usize = new_batches.iter().map(|b| b.num_rows()).sum();
        info!("Migrating {} items to new schema", item_count);

        self.db.drop_table("items").await.map_err(|e| {
            SedimentError::Database(format!("Failed to drop old items table: {}", e))
        })?;

        let schema = Arc::new(item_schema());
        let new_table = self
            .db
            .create_empty_table("items", schema.clone())
            .execute()
            .await
            .map_err(|e| {
                SedimentError::Database(format!("Failed to create new items table: {}", e))
            })?;

        if !new_batches.is_empty() {
            let batches = RecordBatchIterator::new(new_batches.into_iter().map(Ok), schema);
            new_table
                .add(Box::new(batches))
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to insert migrated items: {}", e))
                })?;
        }

        info!("Schema migration completed successfully");
        Ok(())
    }

    /// Projects the columns shared between old and new schemas into a batch
    /// matching `item_schema()`, dropping any legacy columns (e.g. `tags`).
    fn convert_batch_to_new_schema(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        let schema = Arc::new(item_schema());

        let id_col = batch
            .column_by_name("id")
            .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?
            .clone();

        let content_col = batch
            .column_by_name("content")
            .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?
            .clone();

        let project_id_col = batch
            .column_by_name("project_id")
            .ok_or_else(|| SedimentError::Database("Missing project_id column".to_string()))?
            .clone();

        let is_chunked_col = batch
            .column_by_name("is_chunked")
            .ok_or_else(|| SedimentError::Database("Missing is_chunked column".to_string()))?
            .clone();

        let created_at_col = batch
            .column_by_name("created_at")
            .ok_or_else(|| SedimentError::Database("Missing created_at column".to_string()))?
            .clone();

        let vector_col = batch
            .column_by_name("vector")
            .ok_or_else(|| SedimentError::Database("Missing vector column".to_string()))?
            .clone();

        RecordBatch::try_new(
            schema,
            vec![
                id_col,
                content_col,
                project_id_col,
                is_chunked_col,
                created_at_col,
                vector_col,
            ],
        )
        .map_err(|e| SedimentError::Database(format!("Failed to create migrated batch: {}", e)))
    }

    /// Creates a vector index on each table that has enough rows to benefit;
    /// index creation failures are logged but never fatal.
    async fn ensure_vector_index(&self) -> Result<()> {
        const MIN_ROWS_FOR_INDEX: usize = 256;

        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
            if let Some(table) = table_opt {
                let row_count = table.count_rows(None).await.unwrap_or(0);
                if row_count < MIN_ROWS_FOR_INDEX {
                    continue;
                }

                let indices = table.list_indices().await.unwrap_or_default();

                let has_vector_index = indices
                    .iter()
                    .any(|idx| idx.columns.contains(&"vector".to_string()));

                if !has_vector_index {
                    info!(
                        "Creating vector index on {} table ({} rows)",
                        name, row_count
                    );
                    match table
                        .create_index(&["vector"], lancedb::index::Index::Auto)
                        .execute()
                        .await
                    {
                        Ok(_) => info!("Vector index created on {} table", name),
                        Err(e) => {
                            tracing::warn!("Failed to create vector index on {}: {}", name, e);
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Returns the `items` table, creating it empty on first use.
    async fn get_items_table(&mut self) -> Result<&Table> {
        if self.items_table.is_none() {
            let schema = Arc::new(item_schema());
            let table = self
                .db
                .create_empty_table("items", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;
            self.items_table = Some(table);
        }
        Ok(self.items_table.as_ref().unwrap())
    }

    /// Returns the `chunks` table, creating it empty on first use.
    async fn get_chunks_table(&mut self) -> Result<&Table> {
        if self.chunks_table.is_none() {
            let schema = Arc::new(chunk_schema());
            let table = self
                .db
                .create_empty_table("chunks", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
                })?;
            self.chunks_table = Some(table);
        }
        Ok(self.chunks_table.as_ref().unwrap())
    }

    /// Stores an item: fills in the project id, embeds the content, writes
    /// the item row, chunks and embeds long content, then reports any
    /// near-duplicate items as potential conflicts.
    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
        if item.project_id.is_none() {
            item.project_id = self.project_id.clone();
        }

        let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
        item.is_chunked = should_chunk;

        let embedding_text = item.embedding_text();
        let embedding = self.embedder.embed(&embedding_text)?;
        item.embedding = embedding;

        let table = self.get_items_table().await?;
        let batch = item_to_batch(&item)?;
        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));

        table
            .add(Box::new(batches))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;

        if should_chunk {
            let embedder = self.embedder.clone();
            let chunks_table = self.get_chunks_table().await?;

            let content_type = detect_content_type(&item.content);
            let config = ChunkingConfig::default();
            let mut chunk_results = chunk_content(&item.content, content_type, &config);

            if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
                tracing::warn!(
                    "Chunk count {} exceeds limit {}, truncating",
                    chunk_results.len(),
                    MAX_CHUNKS_PER_ITEM
                );
                chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
            }

            for (i, chunk_result) in chunk_results.iter().enumerate() {
                let mut chunk = Chunk::new(&item.id, i, &chunk_result.content);

                if let Some(ctx) = &chunk_result.context {
                    chunk = chunk.with_context(ctx);
                }

                let chunk_embedding = embedder.embed(&chunk.content)?;
                chunk.embedding = chunk_embedding;

                let chunk_batch = chunk_to_batch(&chunk)?;
                let batches =
                    RecordBatchIterator::new(vec![Ok(chunk_batch)], Arc::new(chunk_schema()));

                chunks_table
                    .add(Box::new(batches))
                    .execute()
                    .await
                    .map_err(|e| {
                        SedimentError::Database(format!("Failed to store chunk: {}", e))
                    })?;
            }

            debug!(
                "Stored item: {} with {} chunks",
                item.id,
                chunk_results.len()
            );
        } else {
            debug!("Stored item: {} (no chunking)", item.id);
        }

        let potential_conflicts = self
            .find_similar_items(
                &item.content,
                CONFLICT_SIMILARITY_THRESHOLD,
                CONFLICT_SEARCH_LIMIT,
            )
            .await
            .unwrap_or_default()
            .into_iter()
            .filter(|c| c.id != item.id)
            .collect();

        Ok(StoreResult {
            id: item.id,
            potential_conflicts,
        })
    }

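    // Sketch of consuming a `StoreResult` (illustrative only):
    //
    //     let result = db.store_item(item).await?;
    //     for conflict in &result.potential_conflicts {
    //         println!("similar ({:.2}): {}", conflict.similarity, conflict.content);
    //     }
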
    /// Semantic search across items and chunks, merged by item id, keeping
    /// the best-scoring hit per item.
    pub async fn search_items(
        &mut self,
        query: &str,
        limit: usize,
        filters: ItemFilters,
    ) -> Result<Vec<SearchResult>> {
        // Cap the limit to keep vector search bounded.
        let limit = limit.min(1000);
        self.ensure_vector_index().await?;

        let query_embedding = self.embedder.embed(query)?;
        let min_similarity = filters.min_similarity.unwrap_or(0.3);

        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
            std::collections::HashMap::new();

        if let Some(table) = &self.items_table {
            let query_builder = table
                .vector_search(query_embedding.clone())
                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
                .limit(limit * 2);

            let results = query_builder
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect results: {}", e))
                })?;

            for batch in results {
                let items = batch_to_items(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, item) in items.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    let boosted_similarity = boost_similarity(
                        similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result = SearchResult::from_item(&item, boosted_similarity);
                    results_map
                        .entry(item.id.clone())
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        if let Some(chunks_table) = &self.chunks_table {
            let chunk_results = chunks_table
                .vector_search(query_embedding)
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to build chunk search: {}", e))
                })?
                .limit(limit * 3)
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect chunk results: {}", e))
                })?;

            // Best-matching chunk per parent item: item_id -> (excerpt, similarity).
            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
                std::collections::HashMap::new();

            for batch in chunk_results {
                let chunks = batch_to_chunks(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    chunk_matches
                        .entry(chunk.item_id.clone())
                        .and_modify(|(content, best_sim)| {
                            if similarity > *best_sim {
                                *content = chunk.content.clone();
                                *best_sim = similarity;
                            }
                        })
                        .or_insert((chunk.content.clone(), similarity));
                }
            }

            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
                if let Some(item) = self.get_item(&item_id).await? {
                    let boosted_similarity = boost_similarity(
                        chunk_similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result =
                        SearchResult::from_item_with_excerpt(&item, boosted_similarity, excerpt);

                    results_map
                        .entry(item_id)
                        .and_modify(|(existing, existing_sim)| {
                            if boosted_similarity > *existing_sim {
                                *existing = result.clone();
                                *existing_sim = boosted_similarity;
                            }
                        })
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        let mut search_results: Vec<SearchResult> =
            results_map.into_values().map(|(r, _)| r).collect();
        search_results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        search_results.truncate(limit);

        Ok(search_results)
    }

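    // Note on scoring (assumes LanceDB's default L2 distance metric): a raw
    // distance `d` maps to `similarity = 1 / (1 + d)`, so d = 0.0 scores 1.0
    // and d = 1.0 scores 0.5. The default floor of 0.3 therefore admits any
    // hit with d below roughly 2.33.
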
    /// Finds stored items whose embedding similarity to `content` meets
    /// `min_similarity`, sorted best-first; used for conflict detection.
    pub async fn find_similar_items(
        &mut self,
        content: &str,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let embedding = self.embedder.embed(content)?;

        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let results = table
            .vector_search(embedding)
            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
            .limit(limit)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;

        let mut conflicts = Vec::new();

        for batch in results {
            let items = batch_to_items(&batch)?;
            let distances = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

            for (i, item) in items.into_iter().enumerate() {
                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                let similarity = 1.0 / (1.0 + distance);

                if similarity >= min_similarity {
                    conflicts.push(ConflictInfo {
                        id: item.id,
                        content: item.content,
                        similarity,
                    });
                }
            }
        }

        conflicts.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(conflicts)
    }

    /// Lists items for the given scope. `Project` returns items for the
    /// current project (or nothing when no project is set), `Global` returns
    /// unscoped items, and `All` applies no project filter.
    pub async fn list_items(
        &mut self,
        _filters: ItemFilters,
        limit: Option<usize>,
        scope: crate::ListScope,
    ) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let mut filter_parts = Vec::new();

        match scope {
            crate::ListScope::Project => {
                if let Some(ref pid) = self.project_id {
                    filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
                } else {
                    return Ok(Vec::new());
                }
            }
            crate::ListScope::Global => {
                filter_parts.push("project_id IS NULL".to_string());
            }
            crate::ListScope::All => {}
        }

        let mut query = table.query();

        if !filter_parts.is_empty() {
            let filter_str = filter_parts.join(" AND ");
            query = query.only_if(filter_str);
        }

        if let Some(l) = limit {
            query = query.limit(l);
        }

        let results = query
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

    /// Fetches a single item by id; invalid ids short-circuit to `None`.
    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
        if !is_valid_id(id) {
            return Ok(None);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(None),
        };

        let results = table
            .query()
            .only_if(format!("id = '{}'", sanitize_sql_string(id)))
            .limit(1)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        for batch in results {
            let items = batch_to_items(&batch)?;
            if let Some(item) = items.into_iter().next() {
                return Ok(Some(item));
            }
        }

        Ok(None)
    }

    /// Fetches multiple items by id in a single `IN (...)` query; invalid
    /// ids are silently dropped.
    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let quoted: Vec<String> = ids
            .iter()
            .filter(|id| is_valid_id(id))
            .map(|id| format!("'{}'", sanitize_sql_string(id)))
            .collect();
        if quoted.is_empty() {
            return Ok(Vec::new());
        }
        let filter = format!("id IN ({})", quoted.join(", "));

        let results = table
            .query()
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

    /// Deletes an item and its chunks. Returns `false` when the id is
    /// invalid or the item does not exist.
    pub async fn delete_item(&self, id: &str) -> Result<bool> {
        if !is_valid_id(id) {
            return Ok(false);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(false),
        };

        let exists = self.get_item(id).await?.is_some();
        if !exists {
            return Ok(false);
        }

        if let Some(chunks_table) = &self.chunks_table {
            chunks_table
                .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
                .await
                .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
        }

        table
            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;

        Ok(true)
    }

    /// Returns row counts for both tables.
    pub async fn stats(&self) -> Result<DatabaseStats> {
        let mut stats = DatabaseStats::default();

        if let Some(table) = &self.items_table {
            stats.item_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        if let Some(table) = &self.chunks_table {
            stats.chunk_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        Ok(stats)
    }
}

/// Combines similarity with recency and usage:
///
/// `score = similarity * freshness * frequency`, where
/// `freshness = 1 / (1 + age_days / 30)` (age measured from the last access,
/// falling back to creation time) and
/// `frequency = 1 + 0.1 * ln(1 + access_count)`.
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    if !similarity.is_finite() {
        return 0.0;
    }

    let reference_time = last_accessed_at.unwrap_or(created_at);
    let age_secs = (now - reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    let result = similarity * (freshness * frequency) as f32;
    if result.is_finite() { result } else { 0.0 }
}

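// Worked example (follows directly from the formula above): similarity 0.8,
// created 30 days ago, never accessed gives freshness = 1 / (1 + 1) = 0.5
// and frequency = 1 + 0.1 * ln(1) = 1.0, so the score is 0.8 * 0.5 = 0.4,
// matching `test_score_with_decay_30_day_old` below.
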
/// Heuristically classifies content as JSON, YAML, Markdown, code, or plain
/// text, in that order of precedence.
fn detect_content_type(content: &str) -> ContentType {
    let trimmed = content.trim();

    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
    {
        return ContentType::Json;
    }

    if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
        // Count lines in the first ten that look like `key: value` with a
        // plain identifier key; comments and URLs are excluded.
        let lines: Vec<&str> = trimmed.lines().take(10).collect();
        let yaml_key_count = lines
            .iter()
            .filter(|line| {
                let l = line.trim();
                !l.is_empty()
                    && !l.starts_with('#')
                    && !l.contains("://")
                    && l.contains(": ")
                    && l.split(": ").next().is_some_and(|key| {
                        let k = key.trim_start_matches("- ");
                        !k.is_empty()
                            && k.chars()
                                .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
                    })
            })
            .count();
        if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
            return ContentType::Yaml;
        }
    }

    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
        return ContentType::Markdown;
    }

    let code_patterns = [
        "fn ",
        "pub fn ",
        "def ",
        "class ",
        "function ",
        "const ",
        "let ",
        "var ",
        "import ",
        "export ",
        "struct ",
        "impl ",
        "trait ",
    ];
    let has_code_pattern = trimmed.lines().any(|line| {
        let l = line.trim();
        code_patterns.iter().any(|p| l.starts_with(p))
    });
    if has_code_pattern {
        return ContentType::Code;
    }

    ContentType::Text
}

fn item_to_batch(item: &Item) -> Result<RecordBatch> {
    let schema = Arc::new(item_schema());

    let id = StringArray::from(vec![item.id.as_str()]);
    let content = StringArray::from(vec![item.content.as_str()]);
    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);

    let vector = create_embedding_array(&item.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(content),
            Arc::new(project_id),
            Arc::new(is_chunked),
            Arc::new(created_at),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
    let mut items = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let project_id_col = batch
        .column_by_name("project_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let is_chunked_col = batch
        .column_by_name("is_chunked")
        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());

    let created_at_col = batch
        .column_by_name("created_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    let vector_col = batch
        .column_by_name("vector")
        .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let content = content_col.value(i).to_string();

        let project_id = project_id_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);

        let created_at = created_at_col
            .map(|c| {
                Utc.timestamp_opt(c.value(i), 0)
                    .single()
                    .unwrap_or_else(Utc::now)
            })
            .unwrap_or_else(Utc::now);

        let embedding = vector_col
            .and_then(|col| {
                let value = col.value(i);
                value
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .map(|arr| arr.values().to_vec())
            })
            .unwrap_or_default();

        let item = Item {
            id,
            content,
            embedding,
            project_id,
            is_chunked,
            created_at,
        };

        items.push(item);
    }

    Ok(items)
}

fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
    let schema = Arc::new(chunk_schema());

    let id = StringArray::from(vec![chunk.id.as_str()]);
    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
    let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
    let content = StringArray::from(vec![chunk.content.as_str()]);
    let context = StringArray::from(vec![chunk.context.as_deref()]);

    let vector = create_embedding_array(&chunk.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(item_id),
            Arc::new(chunk_index),
            Arc::new(content),
            Arc::new(context),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
    let mut chunks = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let item_id_col = batch
        .column_by_name("item_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;

    let chunk_index_col = batch
        .column_by_name("chunk_index")
        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let context_col = batch
        .column_by_name("context")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let item_id = item_id_col.value(i).to_string();
        let chunk_index = chunk_index_col.value(i) as usize;
        let content = content_col.value(i).to_string();
        let context = context_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let chunk = Chunk {
            id,
            item_id,
            chunk_index,
            content,
            // Embeddings are not read back from storage.
            embedding: Vec::new(),
            context,
        };

        chunks.push(chunk);
    }

    Ok(chunks)
}

fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
    let values = Float32Array::from(embedding.to_vec());
    let field = Arc::new(Field::new("item", DataType::Float32, true));

    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_with_decay_fresh_item() {
        let now = 1700000000i64;
        let created = now;
        // Zero age: freshness = 1.0; zero accesses: frequency = 1.0.
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 1.0 * 1.0;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        let now = 1700000000i64;
        // 30 days old: freshness = 1 / (1 + 1) = 0.5.
        let created = now - 30 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.5;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        // Accessed just now, so freshness resets to 1.0.
        let last_accessed = now;
        let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        let expected = 0.8 * 1.0 * freq as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        let now = 1700000000i64;
        // 90 days old: freshness = 1 / (1 + 3) = 0.25.
        let created = now - 90 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.25;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
        assert_eq!(sanitize_sql_string("hello"), "hello");
        assert_eq!(sanitize_sql_string("it's"), "it''s");
        assert_eq!(sanitize_sql_string(r"a\'b"), r"a\\''b");
        assert_eq!(sanitize_sql_string(r"path\to\file"), r"path\\to\\file");
    }

    #[test]
    fn test_sanitize_sql_string_strips_null_bytes() {
        assert_eq!(sanitize_sql_string("abc\0def"), "abcdef");
        assert_eq!(sanitize_sql_string("\0' OR 1=1 --"), "'' OR 1=1 ");
        assert_eq!(sanitize_sql_string("clean"), "clean");
    }

    #[test]
    fn test_sanitize_sql_string_strips_semicolons() {
        assert_eq!(
            sanitize_sql_string("a; DROP TABLE items"),
            "a DROP TABLE items"
        );
        assert_eq!(sanitize_sql_string("normal;"), "normal");
    }

    #[test]
    fn test_sanitize_sql_string_strips_comments() {
        // Removing "--" and "/*" leaves the surrounding spaces intact,
        // hence the double space in the expected strings.
        assert_eq!(sanitize_sql_string("val' -- comment"), "val''  comment");
        assert_eq!(sanitize_sql_string("val' /* block */"), "val''  block */");
        assert_eq!(sanitize_sql_string("a--b--c"), "abc");
    }

    #[test]
    fn test_sanitize_sql_string_adversarial_inputs() {
        assert_eq!(
            sanitize_sql_string("'; DROP TABLE items;--"),
            "'' DROP TABLE items"
        );
        // Zero-width and other non-ASCII characters pass through unchanged.
        assert_eq!(
            sanitize_sql_string("hello\u{200B}world"),
            "hello\u{200B}world"
        );
        assert_eq!(sanitize_sql_string(""), "");
        assert_eq!(sanitize_sql_string("\0;\0"), "");
    }

    #[test]
    fn test_is_valid_id() {
        assert!(is_valid_id("550e8400-e29b-41d4-a716-446655440000"));
        assert!(is_valid_id("abcdef0123456789"));
        assert!(!is_valid_id(""));
        assert!(!is_valid_id("'; DROP TABLE items;--"));
        assert!(!is_valid_id("hello world"));
        assert!(!is_valid_id("abc\0def"));
        assert!(!is_valid_id(&"a".repeat(65)));
    }

    #[test]
    fn test_detect_content_type_yaml_not_prose() {
        let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
        let detected = detect_content_type(prose);
        assert_ne!(
            detected,
            ContentType::Yaml,
            "Prose with colons should not be detected as YAML"
        );

        let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_detect_content_type_yaml_with_separator() {
        let yaml = "---\nname: test\nversion: 1.0";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_chunk_threshold_uses_chars_not_bytes() {
        let emoji_content = "😀".repeat(500);
        // 500 chars but 2000 bytes: the threshold must count chars.
        assert_eq!(emoji_content.chars().count(), 500);
        assert_eq!(emoji_content.len(), 2000);
        let should_chunk = emoji_content.chars().count() > CHUNK_THRESHOLD;
        assert!(
            !should_chunk,
            "500 chars should not exceed 1000-char threshold"
        );

        let long_content = "a".repeat(1001);
        let should_chunk = long_content.chars().count() > CHUNK_THRESHOLD;
        assert!(should_chunk, "1001 chars should exceed 1000-char threshold");
    }

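    // Added check (grounded in the detection order above): the JSON parse is
    // tried before code patterns, so a JSON object never falls through to Code.
    #[test]
    fn test_detect_content_type_code_and_json() {
        let code = "fn main() {\n    let x = 1;\n}";
        assert_eq!(detect_content_type(code), ContentType::Code);

        let json = r#"{"key": "value"}"#;
        assert_eq!(detect_content_type(json), ContentType::Json);
    }
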
    #[test]
    fn test_schema_version() {
        let version = SCHEMA_VERSION;
        assert!(version >= 2, "Schema version should be at least 2");
    }
}