use std::path::PathBuf;
use std::sync::Arc;

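/// Sanitizes a value for interpolation into a LanceDB filter expression:
/// doubles single quotes and backslashes, and strips NUL bytes, semicolons,
/// and SQL comment openers (`--`, `/*`). Defense in depth alongside
/// `is_valid_id`.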
fn sanitize_sql_string(s: &str) -> String {
    s.replace('\0', "")
        .replace('\\', "\\\\")
        .replace('\'', "''")
        .replace(';', "")
        .replace("--", "")
        .replace("/*", "")
}

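/// Accepts only UUID-shaped identifiers: non-empty, at most 64 characters,
/// consisting solely of ASCII hex digits and hyphens.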
fn is_valid_id(s: &str) -> bool {
    !s.is_empty() && s.len() <= 64 && s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
}

use arrow_array::{
    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
    RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use chrono::{TimeZone, Utc};
use futures::TryStreamExt;
use lancedb::Table;
use lancedb::connect;
use lancedb::query::{ExecutableQuery, QueryBase};
use tracing::{debug, info};

use crate::boost_similarity;
use crate::chunker::{ChunkingConfig, chunk_content};
use crate::document::ContentType;
use crate::embedder::{EMBEDDING_DIM, Embedder};
use crate::error::{Result, SedimentError};
use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};

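/// Items whose content exceeds this many characters are also stored as chunks.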
const CHUNK_THRESHOLD: usize = 1000;

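/// Minimum similarity for an existing item to be reported as a potential conflict.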
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

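/// Maximum number of potential conflicts returned when storing an item.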
const CONFLICT_SEARCH_LIMIT: usize = 5;

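/// Upper bound on chunks stored per item; excess chunks are dropped with a warning.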
const MAX_CHUNKS_PER_ITEM: usize = 200;

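/// Vector store backed by LanceDB, holding an `items` table for whole
/// documents and a `chunks` table for pieces of long documents. Tables are
/// opened on startup if present and created lazily on first write.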
pub struct Database {
    db: lancedb::Connection,
    embedder: Arc<Embedder>,
    project_id: Option<String>,
    items_table: Option<Table>,
    chunks_table: Option<Table>,
}

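/// Row counts returned by [`Database::stats`].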
#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    pub item_count: usize,
    pub chunk_count: usize,
}

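/// Arrow schema for the `items` table. `tags` and `metadata` are stored as
/// JSON-encoded strings; `expires_at` and `created_at` are Unix timestamps
/// in seconds.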
fn item_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("title", DataType::Utf8, true),
        Field::new("tags", DataType::Utf8, true),
        Field::new("source", DataType::Utf8, true),
        Field::new("metadata", DataType::Utf8, true),
        Field::new("project_id", DataType::Utf8, true),
        Field::new("is_chunked", DataType::Boolean, false),
        Field::new("expires_at", DataType::Int64, true),
        Field::new("created_at", DataType::Int64, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

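/// Arrow schema for the `chunks` table: one row per chunk, keyed back to its
/// parent item by `item_id`.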
fn chunk_schema() -> Schema {
    Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("item_id", DataType::Utf8, false),
        Field::new("chunk_index", DataType::Int32, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("context", DataType::Utf8, true),
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                EMBEDDING_DIM as i32,
            ),
            false,
        ),
    ])
}

impl Database {
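    /// Opens (or creates) the database at `path` with no project scope.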
    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
        Self::open_with_project(path, None).await
    }

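    /// Opens the database scoped to `project_id`, constructing a default embedder.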
    pub async fn open_with_project(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
    ) -> Result<Self> {
        let embedder = Arc::new(Embedder::new()?);
        Self::open_with_embedder(path, project_id, embedder).await
    }

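    /// Opens the database with a caller-supplied embedder: creates the parent
    /// directory if needed, connects, opens any existing tables, and ensures
    /// vector indices are in place.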
    pub async fn open_with_embedder(
        path: impl Into<PathBuf>,
        project_id: Option<String>,
        embedder: Arc<Embedder>,
    ) -> Result<Self> {
        let path = path.into();
        info!("Opening database at {:?}", path);

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                SedimentError::Database(format!("Failed to create database directory: {}", e))
            })?;
        }

        let db = connect(path.to_str().ok_or_else(|| {
            SedimentError::Database("Database path contains invalid UTF-8".to_string())
        })?)
        .execute()
        .await
        .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;

        let mut database = Self {
            db,
            embedder,
            project_id,
            items_table: None,
            chunks_table: None,
        };

        database.ensure_tables().await?;
        database.ensure_vector_index().await?;

        Ok(database)
    }

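    /// Sets or clears the project scope applied to new items and scoped queries.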
    pub fn set_project_id(&mut self, project_id: Option<String>) {
        self.project_id = project_id;
    }

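    /// Returns the current project scope, if any.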
    pub fn project_id(&self) -> Option<&str> {
        self.project_id.as_deref()
    }

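    /// Opens the `items` and `chunks` tables if they already exist; absent
    /// tables are created lazily on first write.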
    async fn ensure_tables(&mut self) -> Result<()> {
        let table_names = self
            .db
            .table_names()
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;

        if table_names.contains(&"items".to_string()) {
            self.items_table =
                Some(self.db.open_table("items").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open items: {}", e))
                })?);
        }

        if table_names.contains(&"chunks".to_string()) {
            self.chunks_table =
                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
                    SedimentError::Database(format!("Failed to open chunks: {}", e))
                })?);
        }

        Ok(())
    }

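    /// Creates a vector index on any table that has reached
    /// `MIN_ROWS_FOR_INDEX` rows but lacks an index on its `vector` column.
    /// Index-creation failures are logged rather than propagated.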
    async fn ensure_vector_index(&self) -> Result<()> {
        const MIN_ROWS_FOR_INDEX: usize = 256;

        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
            if let Some(table) = table_opt {
                let row_count = table.count_rows(None).await.unwrap_or(0);
                if row_count < MIN_ROWS_FOR_INDEX {
                    continue;
                }

                let indices = table.list_indices().await.unwrap_or_default();

                let has_vector_index = indices
                    .iter()
                    .any(|idx| idx.columns.contains(&"vector".to_string()));

                if !has_vector_index {
                    info!(
                        "Creating vector index on {} table ({} rows)",
                        name, row_count
                    );
                    match table
                        .create_index(&["vector"], lancedb::index::Index::Auto)
                        .execute()
                        .await
                    {
                        Ok(_) => info!("Vector index created on {} table", name),
                        Err(e) => {
                            tracing::warn!("Failed to create vector index on {}: {}", name, e);
                        }
                    }
                }
            }
        }

        Ok(())
    }

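    /// Returns the `items` table, creating it empty on first use.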
    async fn get_items_table(&mut self) -> Result<&Table> {
        if self.items_table.is_none() {
            let schema = Arc::new(item_schema());
            let table = self
                .db
                .create_empty_table("items", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create items table: {}", e))
                })?;
            self.items_table = Some(table);
        }
        Ok(self.items_table.as_ref().unwrap())
    }

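    /// Returns the `chunks` table, creating it empty on first use.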
    async fn get_chunks_table(&mut self) -> Result<&Table> {
        if self.chunks_table.is_none() {
            let schema = Arc::new(chunk_schema());
            let table = self
                .db
                .create_empty_table("chunks", schema)
                .execute()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
                })?;
            self.chunks_table = Some(table);
        }
        Ok(self.chunks_table.as_ref().unwrap())
    }

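    /// Stores an item and returns its ID along with potential conflicts:
    /// existing, non-expired items whose similarity to the new content is at
    /// least `CONFLICT_SIMILARITY_THRESHOLD`. Content longer than
    /// `CHUNK_THRESHOLD` characters is additionally split into chunks, each
    /// embedded and stored separately.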
    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
        if item.project_id.is_none() {
            item.project_id = self.project_id.clone();
        }

        let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
        item.is_chunked = should_chunk;

        let embedding_text = item.embedding_text();
        let embedding = self.embedder.embed(&embedding_text)?;
        item.embedding = embedding;

        let table = self.get_items_table().await?;
        let batch = item_to_batch(&item)?;
        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));

        table
            .add(Box::new(batches))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;

        if should_chunk {
            let embedder = self.embedder.clone();
            let chunks_table = self.get_chunks_table().await?;

            let content_type = detect_content_type(&item.content);
            let config = ChunkingConfig::default();
            let mut chunk_results = chunk_content(&item.content, content_type, &config);

            if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
                tracing::warn!(
                    "Chunk count {} exceeds limit {}, truncating",
                    chunk_results.len(),
                    MAX_CHUNKS_PER_ITEM
                );
                chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
            }

            for (i, chunk_result) in chunk_results.iter().enumerate() {
                let mut chunk = Chunk::new(&item.id, i, &chunk_result.content);

                if let Some(ctx) = &chunk_result.context {
                    chunk = chunk.with_context(ctx);
                }

                let chunk_embedding = embedder.embed(&chunk.content)?;
                chunk.embedding = chunk_embedding;

                let chunk_batch = chunk_to_batch(&chunk)?;
                let batches =
                    RecordBatchIterator::new(vec![Ok(chunk_batch)], Arc::new(chunk_schema()));

                chunks_table
                    .add(Box::new(batches))
                    .execute()
                    .await
                    .map_err(|e| {
                        SedimentError::Database(format!("Failed to store chunk: {}", e))
                    })?;
            }

            debug!(
                "Stored item: {} with {} chunks",
                item.id,
                chunk_results.len()
            );
        } else {
            debug!("Stored item: {} (no chunking)", item.id);
        }

        let potential_conflicts = self
            .find_similar_items(
                &item.content,
                CONFLICT_SIMILARITY_THRESHOLD,
                CONFLICT_SEARCH_LIMIT,
            )
            .await
            .unwrap_or_default()
            .into_iter()
            .filter(|c| c.id != item.id)
            .collect();

        Ok(StoreResult {
            id: item.id,
            potential_conflicts,
        })
    }

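    /// Semantic search across whole items and chunks. Hits from both tables
    /// are merged per item (best score wins), similarities are boosted for
    /// the current project via `boost_similarity`, and results below the
    /// similarity floor (`filters.min_similarity`, default 0.3) are dropped.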
    pub async fn search_items(
        &mut self,
        query: &str,
        limit: usize,
        filters: ItemFilters,
    ) -> Result<Vec<SearchResult>> {
        let limit = limit.min(1000);
        self.ensure_vector_index().await?;

        let query_embedding = self.embedder.embed(query)?;
        let min_similarity = filters.min_similarity.unwrap_or(0.3);

        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
            std::collections::HashMap::new();

        if let Some(table) = &self.items_table {
            let mut filter_parts = Vec::new();

            if !filters.include_expired {
                let now = Utc::now().timestamp();
                filter_parts.push(format!("(expires_at IS NULL OR expires_at > {})", now));
            }

            let mut query_builder = table
                .vector_search(query_embedding.clone())
                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
                .limit(limit * 2);

            if !filter_parts.is_empty() {
                let filter_str = filter_parts.join(" AND ");
                query_builder = query_builder.only_if(filter_str);
            }

            let results = query_builder
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect results: {}", e))
                })?;

            for batch in results {
                let items = batch_to_items(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, item) in items.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    if let Some(ref filter_tags) = filters.tags
                        && !filter_tags.iter().any(|t| item.tags.contains(t))
                    {
                        continue;
                    }

                    let boosted_similarity = boost_similarity(
                        similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result = SearchResult::from_item(&item, boosted_similarity);
                    results_map
                        .entry(item.id.clone())
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        if let Some(chunks_table) = &self.chunks_table {
            let chunk_results = chunks_table
                .vector_search(query_embedding)
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to build chunk search: {}", e))
                })?
                .limit(limit * 3)
                .execute()
                .await
                .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
                .try_collect::<Vec<_>>()
                .await
                .map_err(|e| {
                    SedimentError::Database(format!("Failed to collect chunk results: {}", e))
                })?;

            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
                std::collections::HashMap::new();

            for batch in chunk_results {
                let chunks = batch_to_chunks(&batch)?;
                let distances = batch
                    .column_by_name("_distance")
                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                    let similarity = 1.0 / (1.0 + distance);

                    if similarity < min_similarity {
                        continue;
                    }

                    chunk_matches
                        .entry(chunk.item_id.clone())
                        .and_modify(|(content, best_sim)| {
                            if similarity > *best_sim {
                                *content = chunk.content.clone();
                                *best_sim = similarity;
                            }
                        })
                        .or_insert((chunk.content.clone(), similarity));
                }
            }

            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
                if let Some(item) = self.get_item(&item_id).await? {
                    if let Some(ref filter_tags) = filters.tags
                        && !filter_tags.iter().any(|t| item.tags.contains(t))
                    {
                        continue;
                    }

                    let boosted_similarity = boost_similarity(
                        chunk_similarity,
                        item.project_id.as_deref(),
                        self.project_id.as_deref(),
                    );

                    let result =
                        SearchResult::from_item_with_excerpt(&item, boosted_similarity, excerpt);

                    results_map
                        .entry(item_id)
                        .and_modify(|(existing, existing_sim)| {
                            if boosted_similarity > *existing_sim {
                                *existing = result.clone();
                                *existing_sim = boosted_similarity;
                            }
                        })
                        .or_insert((result, boosted_similarity));
                }
            }
        }

        let mut search_results: Vec<SearchResult> =
            results_map.into_values().map(|(r, _)| r).collect();
        search_results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        search_results.truncate(limit);

        Ok(search_results)
    }

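    /// Returns non-expired items whose similarity to `content` is at least
    /// `min_similarity`, sorted by descending similarity. Backs the conflict
    /// check in `store_item`.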
    pub async fn find_similar_items(
        &mut self,
        content: &str,
        min_similarity: f32,
        limit: usize,
    ) -> Result<Vec<ConflictInfo>> {
        let embedding = self.embedder.embed(content)?;

        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let now = Utc::now().timestamp();
        let filter = format!("(expires_at IS NULL OR expires_at > {})", now);

        let results = table
            .vector_search(embedding)
            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
            .limit(limit)
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;

        let mut conflicts = Vec::new();

        for batch in results {
            let items = batch_to_items(&batch)?;
            let distances = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

            for (i, item) in items.into_iter().enumerate() {
                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
                let similarity = 1.0 / (1.0 + distance);

                if similarity >= min_similarity {
                    conflicts.push(ConflictInfo {
                        id: item.id,
                        content: item.content,
                        similarity,
                    });
                }
            }
        }

        conflicts.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(conflicts)
    }

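    /// Lists items for the given scope (current project, global-only, or
    /// everything), honoring expiry and tag filters. Project scope with no
    /// project set yields an empty list.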
    pub async fn list_items(
        &mut self,
        filters: ItemFilters,
        limit: Option<usize>,
        scope: crate::ListScope,
    ) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        let mut filter_parts = Vec::new();

        if !filters.include_expired {
            let now = Utc::now().timestamp();
            filter_parts.push(format!("(expires_at IS NULL OR expires_at > {})", now));
        }

        match scope {
            crate::ListScope::Project => {
                if let Some(ref pid) = self.project_id {
                    filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
                } else {
                    return Ok(Vec::new());
                }
            }
            crate::ListScope::Global => {
                filter_parts.push("project_id IS NULL".to_string());
            }
            crate::ListScope::All => {}
        }

        let mut query = table.query();

        if !filter_parts.is_empty() {
            let filter_str = filter_parts.join(" AND ");
            query = query.only_if(filter_str);
        }

        if let Some(l) = limit {
            query = query.limit(l);
        }

        let results = query
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        if let Some(ref filter_tags) = filters.tags {
            items.retain(|item| filter_tags.iter().any(|t| item.tags.contains(t)));
        }

        Ok(items)
    }

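    /// Fetches an item by ID; malformed IDs return `Ok(None)` rather than an error.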
    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
        if !is_valid_id(id) {
            return Ok(None);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(None),
        };

        let results = table
            .query()
            .only_if(format!("id = '{}'", sanitize_sql_string(id)))
            .limit(1)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;

        for batch in results {
            let items = batch_to_items(&batch)?;
            if let Some(item) = items.into_iter().next() {
                return Ok(Some(item));
            }
        }

        Ok(None)
    }

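    /// Fetches several items in a single query; IDs that fail validation are skipped.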
    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(Vec::new()),
        };

        if ids.is_empty() {
            return Ok(Vec::new());
        }

        let quoted: Vec<String> = ids
            .iter()
            .filter(|id| is_valid_id(id))
            .map(|id| format!("'{}'", sanitize_sql_string(id)))
            .collect();
        if quoted.is_empty() {
            return Ok(Vec::new());
        }
        let filter = format!("id IN ({})", quoted.join(", "));

        let results = table
            .query()
            .only_if(filter)
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;

        let mut items = Vec::new();
        for batch in results {
            items.extend(batch_to_items(&batch)?);
        }

        Ok(items)
    }

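    /// Sets an item's expiry by deleting the row and re-inserting it with
    /// `expires_at` set. The re-insert is retried up to three times with
    /// exponential backoff; if all attempts fail, a best-effort restore of
    /// the original row is attempted before returning the error.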
    pub async fn expire_item(&self, id: &str, expires_at: chrono::DateTime<Utc>) -> Result<()> {
        if !is_valid_id(id) {
            return Err(SedimentError::Database("Invalid item ID".to_string()));
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Err(SedimentError::Database("Items table not found".to_string())),
        };

        let original_item = self.get_item(id).await?;
        let original_item = match original_item {
            Some(i) => i,
            None => return Err(SedimentError::Database(format!("Item not found: {}", id))),
        };

        let mut item = original_item.clone();
        item.expires_at = Some(expires_at);

        table
            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete for expire failed: {}", e)))?;

        let mut last_err = None;
        for attempt in 0..3 {
            let batch = item_to_batch(&item)?;
            let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));
            match table.add(Box::new(batches)).execute().await {
                Ok(_) => return Ok(()),
                Err(e) => {
                    tracing::warn!(
                        "Re-insert for expire failed (attempt {}/3): {}",
                        attempt + 1,
                        e
                    );
                    last_err = Some(e);
                    tokio::time::sleep(std::time::Duration::from_millis(100 * (1 << attempt)))
                        .await;
                }
            }
        }

        tracing::error!("expire_item: re-insert failed after 3 attempts, attempting recovery");
        let batch = item_to_batch(&original_item)?;
        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));
        if let Err(recovery_err) = table.add(Box::new(batches)).execute().await {
            tracing::error!(
                "expire_item: CRITICAL - recovery also failed, item {} may be lost: {}",
                id,
                recovery_err
            );
        }

        Err(SedimentError::Database(format!(
            "Re-insert for expire failed after 3 attempts: {}",
            last_err.unwrap()
        )))
    }

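    /// Deletes an item and its chunks. Returns `Ok(false)` if the ID is
    /// invalid or unknown.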
    pub async fn delete_item(&self, id: &str) -> Result<bool> {
        if !is_valid_id(id) {
            return Ok(false);
        }
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(false),
        };

        let exists = self.get_item(id).await?.is_some();
        if !exists {
            return Ok(false);
        }

        if let Some(chunks_table) = &self.chunks_table {
            chunks_table
                .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
                .await
                .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
        }

        table
            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
            .await
            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;

        Ok(true)
    }

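    /// Returns current row counts for the items and chunks tables.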
    pub async fn stats(&self) -> Result<DatabaseStats> {
        let mut stats = DatabaseStats::default();

        if let Some(table) = &self.items_table {
            stats.item_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        if let Some(table) = &self.chunks_table {
            stats.chunk_count = table
                .count_rows(None)
                .await
                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
        }

        Ok(stats)
    }

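    /// Deletes all expired items and their chunks, returning how many items
    /// were removed.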
    pub async fn cleanup_expired(&self) -> Result<usize> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(0),
        };

        let now = Utc::now().timestamp();
        let filter = format!("expires_at IS NOT NULL AND expires_at < {}", now);

        let count = table.count_rows(Some(filter.clone())).await.unwrap_or(0);

        if count > 0 {
            if let Ok(expired_ids) = self.get_expired_item_ids(now).await
                && let Some(ref chunks_table) = self.chunks_table
            {
                for item_id in &expired_ids {
                    let chunk_filter = format!("item_id = '{}'", sanitize_sql_string(item_id));
                    if let Err(e) = chunks_table.delete(&chunk_filter).await {
                        tracing::warn!(
                            "Failed to delete chunks for expired item {}: {}",
                            item_id,
                            e
                        );
                    }
                }
            }

            table
                .delete(&filter)
                .await
                .map_err(|e| SedimentError::Database(format!("Expired cleanup failed: {}", e)))?;

            info!("Cleaned up {} expired items and their chunks", count);
        }

        Ok(count)
    }

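    /// Collects the IDs of items whose `expires_at` lies before `now_ts`.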
    async fn get_expired_item_ids(&self, now_ts: i64) -> Result<Vec<String>> {
        let table = match &self.items_table {
            Some(t) => t,
            None => return Ok(vec![]),
        };

        let filter = format!("expires_at IS NOT NULL AND expires_at < {}", now_ts);
        let results = table
            .query()
            .only_if(filter)
            .select(lancedb::query::Select::Columns(vec!["id".to_string()]))
            .execute()
            .await
            .map_err(|e| SedimentError::Database(format!("Query expired IDs failed: {}", e)))?;

        let batches = results
            .try_collect::<Vec<_>>()
            .await
            .map_err(|e| SedimentError::Database(format!("Collect expired IDs failed: {}", e)))?;

        let mut ids = Vec::new();
        for batch in &batches {
            if let Some(id_col) = batch.column_by_name("id") {
                let id_array = match id_col.as_any().downcast_ref::<StringArray>() {
                    Some(arr) => arr,
                    None => continue,
                };
                for i in 0..id_array.len() {
                    if !id_array.is_null(i) {
                        ids.push(id_array.value(i).to_string());
                    }
                }
            }
        }

        Ok(ids)
    }
}

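/// Combines raw vector similarity with recency and usage:
///
/// `score = similarity * freshness * frequency`
///
/// where `freshness = 1 / (1 + age_days / 30)` (age measured from
/// `last_accessed_at`, falling back to `created_at`) and
/// `frequency = 1 + 0.1 * ln(1 + access_count)`. A 30-day-old, never-accessed
/// item keeps half its similarity. Non-finite inputs or results yield `0.0`.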
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    if !similarity.is_finite() {
        return 0.0;
    }

    let reference_time = last_accessed_at.unwrap_or(created_at);
    let age_secs = (now - reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    let result = similarity * (freshness * frequency) as f32;
    if result.is_finite() { result } else { 0.0 }
}

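/// Best-effort content classification used to pick a chunking strategy:
/// JSON (must parse), YAML (at least two identifier-like `key: value` lines
/// in the first ten, or one plus a leading `---`), Markdown (ATX headings),
/// code (common declaration keywords at line starts), else plain text.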
fn detect_content_type(content: &str) -> ContentType {
    let trimmed = content.trim();

    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
    {
        return ContentType::Json;
    }

    if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
        let lines: Vec<&str> = trimmed.lines().take(10).collect();
        let yaml_key_count = lines
            .iter()
            .filter(|line| {
                let l = line.trim();
                !l.is_empty()
                    && !l.starts_with('#')
                    && !l.contains("://")
                    && l.contains(": ")
                    && l.split(": ").next().is_some_and(|key| {
                        let k = key.trim_start_matches("- ");
                        !k.is_empty()
                            && k.chars()
                                .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
                    })
            })
            .count();
        if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
            return ContentType::Yaml;
        }
    }

    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
        return ContentType::Markdown;
    }

    let code_patterns = [
        "fn ",
        "pub fn ",
        "def ",
        "class ",
        "function ",
        "const ",
        "let ",
        "var ",
        "import ",
        "export ",
        "struct ",
        "impl ",
        "trait ",
    ];
    let has_code_pattern = trimmed.lines().any(|line| {
        let l = line.trim();
        code_patterns.iter().any(|p| l.starts_with(p))
    });
    if has_code_pattern {
        return ContentType::Code;
    }

    ContentType::Text
}

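/// Builds a single-row Arrow batch for an item; tags and metadata are
/// JSON-serialized into string columns.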
fn item_to_batch(item: &Item) -> Result<RecordBatch> {
    let schema = Arc::new(item_schema());

    let id = StringArray::from(vec![item.id.as_str()]);
    let content = StringArray::from(vec![item.content.as_str()]);
    let title = StringArray::from(vec![item.title.as_deref()]);
    let tags = StringArray::from(vec![serde_json::to_string(&item.tags).ok()]);
    let source = StringArray::from(vec![item.source.as_deref()]);
    let metadata = StringArray::from(vec![item.metadata.as_ref().map(|m| m.to_string())]);
    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
    let expires_at = Int64Array::from(vec![item.expires_at.map(|t| t.timestamp())]);
    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);

    let vector = create_embedding_array(&item.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(content),
            Arc::new(title),
            Arc::new(tags),
            Arc::new(source),
            Arc::new(metadata),
            Arc::new(project_id),
            Arc::new(is_chunked),
            Arc::new(expires_at),
            Arc::new(created_at),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

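/// Decodes items from an Arrow batch. Optional columns tolerate nulls and
/// missing columns; out-of-range timestamps fall back to `Utc::now()`.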
fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
    let mut items = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let title_col = batch
        .column_by_name("title")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let tags_col = batch
        .column_by_name("tags")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let source_col = batch
        .column_by_name("source")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let metadata_col = batch
        .column_by_name("metadata")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let project_id_col = batch
        .column_by_name("project_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    let is_chunked_col = batch
        .column_by_name("is_chunked")
        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());

    let expires_at_col = batch
        .column_by_name("expires_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    let created_at_col = batch
        .column_by_name("created_at")
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());

    let vector_col = batch
        .column_by_name("vector")
        .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let content = content_col.value(i).to_string();

        let title = title_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let tags: Vec<String> = tags_col
            .and_then(|c| {
                if c.is_null(i) {
                    None
                } else {
                    serde_json::from_str(c.value(i)).ok()
                }
            })
            .unwrap_or_default();

        let source = source_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let metadata = metadata_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                serde_json::from_str(c.value(i)).ok()
            }
        });

        let project_id = project_id_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);

        let expires_at = expires_at_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(
                    Utc.timestamp_opt(c.value(i), 0)
                        .single()
                        .unwrap_or_else(Utc::now),
                )
            }
        });

        let created_at = created_at_col
            .map(|c| {
                Utc.timestamp_opt(c.value(i), 0)
                    .single()
                    .unwrap_or_else(Utc::now)
            })
            .unwrap_or_else(Utc::now);

        let embedding = vector_col
            .and_then(|col| {
                let value = col.value(i);
                value
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .map(|arr| arr.values().to_vec())
            })
            .unwrap_or_default();

        let item = Item {
            id,
            content,
            embedding,
            title,
            tags,
            source,
            metadata,
            project_id,
            is_chunked,
            expires_at,
            created_at,
        };

        items.push(item);
    }

    Ok(items)
}

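/// Builds a single-row Arrow batch for a chunk.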
fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
    let schema = Arc::new(chunk_schema());

    let id = StringArray::from(vec![chunk.id.as_str()]);
    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
    let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
    let content = StringArray::from(vec![chunk.content.as_str()]);
    let context = StringArray::from(vec![chunk.context.as_deref()]);

    let vector = create_embedding_array(&chunk.embedding)?;

    RecordBatch::try_new(
        schema,
        vec![
            Arc::new(id),
            Arc::new(item_id),
            Arc::new(chunk_index),
            Arc::new(content),
            Arc::new(context),
            Arc::new(vector),
        ],
    )
    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
}

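/// Decodes chunks from an Arrow batch. The stored vector column is not read
/// back; `embedding` is left empty.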
fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
    let mut chunks = Vec::new();

    let id_col = batch
        .column_by_name("id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;

    let item_id_col = batch
        .column_by_name("item_id")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;

    let chunk_index_col = batch
        .column_by_name("chunk_index")
        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;

    let content_col = batch
        .column_by_name("content")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;

    let context_col = batch
        .column_by_name("context")
        .and_then(|c| c.as_any().downcast_ref::<StringArray>());

    for i in 0..batch.num_rows() {
        let id = id_col.value(i).to_string();
        let item_id = item_id_col.value(i).to_string();
        let chunk_index = chunk_index_col.value(i) as usize;
        let content = content_col.value(i).to_string();
        let context = context_col.and_then(|c| {
            if c.is_null(i) {
                None
            } else {
                Some(c.value(i).to_string())
            }
        });

        let chunk = Chunk {
            id,
            item_id,
            chunk_index,
            content,
            embedding: Vec::new(),
            context,
        };

        chunks.push(chunk);
    }

    Ok(chunks)
}

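/// Wraps an embedding in a `FixedSizeListArray` of width `EMBEDDING_DIM`;
/// errors if the length does not match.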
fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
    let values = Float32Array::from(embedding.to_vec());
    let field = Arc::new(Field::new("item", DataType::Float32, true));

    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_with_decay_fresh_item() {
        let now = 1700000000i64;
        let created = now;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 1.0 * 1.0;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.5;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        let now = 1700000000i64;
        let created = now - 30 * 86400;
        let last_accessed = now;
        let score = score_with_decay(0.8, now, created, 10, Some(last_accessed));
        let freq = 1.0 + 0.1 * (11.0_f64).ln();
        let expected = 0.8 * 1.0 * freq as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        let now = 1700000000i64;
        let created = now - 90 * 86400;
        let score = score_with_decay(0.8, now, created, 0, None);
        let expected = 0.8 * 0.25;
        assert!((score - expected).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
        assert_eq!(sanitize_sql_string("hello"), "hello");
        assert_eq!(sanitize_sql_string("it's"), "it''s");
        assert_eq!(sanitize_sql_string(r"a\'b"), r"a\\''b");
        assert_eq!(sanitize_sql_string(r"path\to\file"), r"path\\to\\file");
    }

    #[test]
    fn test_sanitize_sql_string_strips_null_bytes() {
        assert_eq!(sanitize_sql_string("abc\0def"), "abcdef");
        assert_eq!(sanitize_sql_string("\0' OR 1=1 --"), "'' OR 1=1 ");
        assert_eq!(sanitize_sql_string("clean"), "clean");
    }

    #[test]
    fn test_sanitize_sql_string_strips_semicolons() {
        assert_eq!(
            sanitize_sql_string("a; DROP TABLE items"),
            "a DROP TABLE items"
        );
        assert_eq!(sanitize_sql_string("normal;"), "normal");
    }

    #[test]
    fn test_sanitize_sql_string_strips_comments() {
        // "--" and "/*" are deleted outright, so the spaces around them remain.
        assert_eq!(sanitize_sql_string("val' -- comment"), "val''  comment");
        assert_eq!(sanitize_sql_string("val' /* block */"), "val''  block */");
        assert_eq!(sanitize_sql_string("a--b--c"), "abc");
    }

    #[test]
    fn test_sanitize_sql_string_adversarial_inputs() {
        assert_eq!(
            sanitize_sql_string("'; DROP TABLE items;--"),
            "'' DROP TABLE items"
        );
        assert_eq!(
            sanitize_sql_string("hello\u{200B}world"),
            "hello\u{200B}world"
        );
        assert_eq!(sanitize_sql_string(""), "");
        assert_eq!(sanitize_sql_string("\0;\0"), "");
    }

    #[test]
    fn test_is_valid_id() {
        assert!(is_valid_id("550e8400-e29b-41d4-a716-446655440000"));
        assert!(is_valid_id("abcdef0123456789"));
        assert!(!is_valid_id(""));
        assert!(!is_valid_id("'; DROP TABLE items;--"));
        assert!(!is_valid_id("hello world"));
        assert!(!is_valid_id("abc\0def"));
        assert!(!is_valid_id(&"a".repeat(65)));
    }

    #[test]
    fn test_detect_content_type_yaml_not_prose() {
        let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
        let detected = detect_content_type(prose);
        assert_ne!(
            detected,
            ContentType::Yaml,
            "Prose with colons should not be detected as YAML"
        );

        let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_detect_content_type_yaml_with_separator() {
        let yaml = "---\nname: test\nversion: 1.0";
        let detected = detect_content_type(yaml);
        assert_eq!(detected, ContentType::Yaml);
    }

    #[test]
    fn test_chunk_threshold_uses_chars_not_bytes() {
        let emoji_content = "😀".repeat(500);
        assert_eq!(emoji_content.chars().count(), 500);
        assert_eq!(emoji_content.len(), 2000); // 4 UTF-8 bytes per emoji

        let should_chunk = emoji_content.chars().count() > CHUNK_THRESHOLD;
        assert!(
            !should_chunk,
            "500 chars should not exceed 1000-char threshold"
        );

        let long_content = "a".repeat(1001);
        let should_chunk = long_content.chars().count() > CHUNK_THRESHOLD;
        assert!(should_chunk, "1001 chars should exceed 1000-char threshold");
    }
}