Skip to main content

sediment/
db.rs

1//! Database module using LanceDB for vector storage
2//!
3//! Provides a simple interface for storing and searching items
4//! using LanceDB's native vector search capabilities.
5
6use std::path::PathBuf;
7use std::sync::Arc;
8
/// Sanitize a string value for use in LanceDB SQL filter expressions.
///
/// LanceDB uses DataFusion as its SQL engine. Since `only_if()` doesn't support
/// parameterized queries, we must escape string literals. This function:
/// - strips null bytes (they could truncate strings in some parsers),
/// - strips semicolons (statement injection),
/// - strips the comment openers `--` and `/*` (comment injection),
/// - escapes backslashes (escape-sequence injection),
/// - doubles single quotes per the SQL standard.
///
/// Stripping is repeated until the string stops changing. A single pass is not
/// enough because removing one token can join its neighbours into another; for
/// example, `-/*-` loses `/*` and becomes `--`. Escaping runs last because doubling
/// `\` or `'` can never produce a stripped token, so the output contains no NUL,
/// `;`, `--` or `/*`.
fn sanitize_sql_string(s: &str) -> String {
    let mut stripped = s.to_owned();
    loop {
        let next = stripped
            .replace('\0', "")
            .replace(';', "")
            .replace("--", "")
            .replace("/*", "");
        // Every pass that changes the string makes it strictly shorter, so this terminates.
        if next == stripped {
            break;
        }
        stripped = next;
    }
    stripped.replace('\\', "\\\\").replace('\'', "''")
}
26
/// Check whether `s` has the shape of an item/project ID: between 1 and 64 bytes,
/// each one an ASCII hex digit or a hyphen (e.g. a hyphenated UUID).
///
/// Strings that pass contain nothing that needs SQL escaping. Use this as an
/// additional guard before `sanitize_sql_string` for ID fields.
fn is_valid_id(s: &str) -> bool {
    if s.is_empty() || s.len() > 64 {
        return false;
    }
    // Non-ASCII characters encode to bytes >= 0x80, which never match below.
    s.bytes()
        .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F' | b'-'))
}
33
34use arrow_array::{
35    Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch,
36    RecordBatchIterator, StringArray,
37};
38use arrow_schema::{DataType, Field, Schema};
39use chrono::{TimeZone, Utc};
40use futures::TryStreamExt;
41use lancedb::Table;
42use lancedb::connect;
43use lancedb::query::{ExecutableQuery, QueryBase};
44use tracing::{debug, info};
45
46use crate::boost_similarity;
47use crate::chunker::{ChunkingConfig, chunk_content};
48use crate::document::ContentType;
49use crate::embedder::{EMBEDDING_DIM, Embedder};
50use crate::error::{Result, SedimentError};
51use crate::item::{Chunk, ConflictInfo, Item, ItemFilters, SearchResult, StoreResult};
52
/// Threshold for auto-chunking (in characters)
///
/// Compared against `content.chars().count()` in `store_item`, so the unit is
/// Unicode scalar values rather than bytes. Content strictly longer than this is chunked.
const CHUNK_THRESHOLD: usize = 1000;

/// Similarity threshold for conflict detection
///
/// Similarity is `1 / (1 + distance)`, derived from LanceDB's `_distance` column.
const CONFLICT_SIMILARITY_THRESHOLD: f32 = 0.85;

/// Maximum number of conflicts to return
///
/// Upper bound on the potential conflicts reported in a `StoreResult`.
const CONFLICT_SEARCH_LIMIT: usize = 5;

/// Maximum number of chunks per item to prevent CPU exhaustion during embedding.
/// With default config (800 char chunks), 1MB content produces ~1250 chunks.
/// Cap at 200 to bound embedding time while covering most legitimate content.
///
/// Chunks past this limit are dropped with a warning in `store_item`, so the tail
/// of very long content gets no chunk-level embeddings.
const MAX_CHUNKS_PER_ITEM: usize = 200;
66
/// Database wrapper for LanceDB
///
/// Bundles the connection, the embedder used for items, chunks and queries, the
/// current project scope, and handles to the `items` and `chunks` tables.
pub struct Database {
    /// Underlying LanceDB connection.
    db: lancedb::Connection,
    /// Embedder shared via `Arc` so a loaded model can be reused across
    /// connections (see `open_with_embedder`).
    embedder: Arc<Embedder>,
    /// Current project scope: assigned to stored items that have no project of
    /// their own, used by project-scoped listing, and used for search boosting.
    project_id: Option<String>,
    /// `None` until an existing `items` table is opened or a new one is created on first write.
    items_table: Option<Table>,
    /// `None` until an existing `chunks` table is opened or a new one is created on first write.
    chunks_table: Option<Table>,
}
75
/// Database statistics
///
/// NOTE(review): not populated in the visible part of this file; presumably row
/// counts of the `items` and `chunks` tables — confirm against the stats method.
#[derive(Debug, Default, Clone)]
pub struct DatabaseStats {
    /// Number of stored items (presumably rows in `items`).
    pub item_count: usize,
    /// Number of stored chunks (presumably rows in `chunks`).
    pub chunk_count: usize,
}
82
/// Current schema version. Increment when making breaking schema changes.
///
/// NOTE(review): in the visible code this value only appears in the migration log
/// message. Whether a migration runs is decided by `check_needs_migration`, which
/// looks for a legacy `tags` column rather than comparing versions.
const SCHEMA_VERSION: i32 = 2;
85
86// Arrow schema builders
87fn item_schema() -> Schema {
88    Schema::new(vec![
89        Field::new("id", DataType::Utf8, false),
90        Field::new("content", DataType::Utf8, false),
91        Field::new("project_id", DataType::Utf8, true),
92        Field::new("is_chunked", DataType::Boolean, false),
93        Field::new("created_at", DataType::Int64, false), // Unix timestamp
94        Field::new(
95            "vector",
96            DataType::FixedSizeList(
97                Arc::new(Field::new("item", DataType::Float32, true)),
98                EMBEDDING_DIM as i32,
99            ),
100            false,
101        ),
102    ])
103}
104
105fn chunk_schema() -> Schema {
106    Schema::new(vec![
107        Field::new("id", DataType::Utf8, false),
108        Field::new("item_id", DataType::Utf8, false),
109        Field::new("chunk_index", DataType::Int32, false),
110        Field::new("content", DataType::Utf8, false),
111        Field::new("context", DataType::Utf8, true),
112        Field::new(
113            "vector",
114            DataType::FixedSizeList(
115                Arc::new(Field::new("item", DataType::Float32, true)),
116                EMBEDDING_DIM as i32,
117            ),
118            false,
119        ),
120    ])
121}
122
123impl Database {
124    /// Open or create a database at the given path
125    pub async fn open(path: impl Into<PathBuf>) -> Result<Self> {
126        Self::open_with_project(path, None).await
127    }
128
129    /// Open or create a database at the given path with a project ID
130    pub async fn open_with_project(
131        path: impl Into<PathBuf>,
132        project_id: Option<String>,
133    ) -> Result<Self> {
134        let embedder = Arc::new(Embedder::new()?);
135        Self::open_with_embedder(path, project_id, embedder).await
136    }
137
138    /// Open or create a database with a pre-existing embedder.
139    ///
140    /// This constructor is useful for connection pooling scenarios where
141    /// the expensive embedder should be loaded once and shared across
142    /// multiple database connections.
143    ///
144    /// # Arguments
145    ///
146    /// * `path` - Path to the database directory
147    /// * `project_id` - Optional project ID for scoped operations
148    /// * `embedder` - Shared embedder instance
149    pub async fn open_with_embedder(
150        path: impl Into<PathBuf>,
151        project_id: Option<String>,
152        embedder: Arc<Embedder>,
153    ) -> Result<Self> {
154        let path = path.into();
155        info!("Opening database at {:?}", path);
156
157        // Ensure parent directory exists
158        if let Some(parent) = path.parent() {
159            std::fs::create_dir_all(parent).map_err(|e| {
160                SedimentError::Database(format!("Failed to create database directory: {}", e))
161            })?;
162        }
163
164        let db = connect(path.to_str().ok_or_else(|| {
165            SedimentError::Database("Database path contains invalid UTF-8".to_string())
166        })?)
167        .execute()
168        .await
169        .map_err(|e| SedimentError::Database(format!("Failed to connect to database: {}", e)))?;
170
171        let mut database = Self {
172            db,
173            embedder,
174            project_id,
175            items_table: None,
176            chunks_table: None,
177        };
178
179        database.ensure_tables().await?;
180        database.ensure_vector_index().await?;
181
182        Ok(database)
183    }
184
185    /// Set the current project ID for scoped operations
186    pub fn set_project_id(&mut self, project_id: Option<String>) {
187        self.project_id = project_id;
188    }
189
190    /// Get the current project ID
191    pub fn project_id(&self) -> Option<&str> {
192        self.project_id.as_deref()
193    }
194
195    /// Ensure all required tables exist, migrating schema if needed
196    async fn ensure_tables(&mut self) -> Result<()> {
197        // Check for existing tables
198        let table_names = self
199            .db
200            .table_names()
201            .execute()
202            .await
203            .map_err(|e| SedimentError::Database(format!("Failed to list tables: {}", e)))?;
204
205        // Check if migration is needed (items table exists but has old schema)
206        if table_names.contains(&"items".to_string()) {
207            let needs_migration = self.check_needs_migration().await?;
208            if needs_migration {
209                info!("Migrating database schema to version {}", SCHEMA_VERSION);
210                self.migrate_schema().await?;
211            }
212        }
213
214        // Items table
215        if table_names.contains(&"items".to_string()) {
216            self.items_table =
217                Some(self.db.open_table("items").execute().await.map_err(|e| {
218                    SedimentError::Database(format!("Failed to open items: {}", e))
219                })?);
220        }
221
222        // Chunks table
223        if table_names.contains(&"chunks".to_string()) {
224            self.chunks_table =
225                Some(self.db.open_table("chunks").execute().await.map_err(|e| {
226                    SedimentError::Database(format!("Failed to open chunks: {}", e))
227                })?);
228        }
229
230        Ok(())
231    }
232
233    /// Check if the database needs migration by checking for old schema columns
234    async fn check_needs_migration(&self) -> Result<bool> {
235        let table = self.db.open_table("items").execute().await.map_err(|e| {
236            SedimentError::Database(format!("Failed to open items for check: {}", e))
237        })?;
238
239        let schema = table
240            .schema()
241            .await
242            .map_err(|e| SedimentError::Database(format!("Failed to get schema: {}", e)))?;
243
244        // Old schema has 'tags' column, new schema doesn't
245        let has_tags = schema.fields().iter().any(|f| f.name() == "tags");
246        Ok(has_tags)
247    }
248
249    /// Migrate from old schema to new schema
250    async fn migrate_schema(&mut self) -> Result<()> {
251        info!("Starting schema migration...");
252
253        // Open old table
254        let old_table = self
255            .db
256            .open_table("items")
257            .execute()
258            .await
259            .map_err(|e| SedimentError::Database(format!("Failed to open old items: {}", e)))?;
260
261        // Read all items from old table
262        let results = old_table
263            .query()
264            .execute()
265            .await
266            .map_err(|e| SedimentError::Database(format!("Migration query failed: {}", e)))?
267            .try_collect::<Vec<_>>()
268            .await
269            .map_err(|e| SedimentError::Database(format!("Migration collect failed: {}", e)))?;
270
271        // Convert old items to new format
272        let mut new_batches = Vec::new();
273        for batch in &results {
274            let converted = self.convert_batch_to_new_schema(batch)?;
275            new_batches.push(converted);
276        }
277
278        let item_count: usize = new_batches.iter().map(|b| b.num_rows()).sum();
279        info!("Migrating {} items to new schema", item_count);
280
281        // Drop old table
282        self.db.drop_table("items").await.map_err(|e| {
283            SedimentError::Database(format!("Failed to drop old items table: {}", e))
284        })?;
285
286        // Create new table with new schema
287        let schema = Arc::new(item_schema());
288        let new_table = self
289            .db
290            .create_empty_table("items", schema.clone())
291            .execute()
292            .await
293            .map_err(|e| {
294                SedimentError::Database(format!("Failed to create new items table: {}", e))
295            })?;
296
297        // Insert migrated data
298        if !new_batches.is_empty() {
299            let batches = RecordBatchIterator::new(new_batches.into_iter().map(Ok), schema);
300            new_table
301                .add(Box::new(batches))
302                .execute()
303                .await
304                .map_err(|e| {
305                    SedimentError::Database(format!("Failed to insert migrated items: {}", e))
306                })?;
307        }
308
309        info!("Schema migration completed successfully");
310        Ok(())
311    }
312
313    /// Convert a batch from old schema to new schema
314    fn convert_batch_to_new_schema(&self, batch: &RecordBatch) -> Result<RecordBatch> {
315        let schema = Arc::new(item_schema());
316
317        // Extract columns from old batch (handle missing columns gracefully)
318        let id_col = batch
319            .column_by_name("id")
320            .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?
321            .clone();
322
323        let content_col = batch
324            .column_by_name("content")
325            .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?
326            .clone();
327
328        let project_id_col = batch
329            .column_by_name("project_id")
330            .ok_or_else(|| SedimentError::Database("Missing project_id column".to_string()))?
331            .clone();
332
333        let is_chunked_col = batch
334            .column_by_name("is_chunked")
335            .ok_or_else(|| SedimentError::Database("Missing is_chunked column".to_string()))?
336            .clone();
337
338        let created_at_col = batch
339            .column_by_name("created_at")
340            .ok_or_else(|| SedimentError::Database("Missing created_at column".to_string()))?
341            .clone();
342
343        let vector_col = batch
344            .column_by_name("vector")
345            .ok_or_else(|| SedimentError::Database("Missing vector column".to_string()))?
346            .clone();
347
348        RecordBatch::try_new(
349            schema,
350            vec![
351                id_col,
352                content_col,
353                project_id_col,
354                is_chunked_col,
355                created_at_col,
356                vector_col,
357            ],
358        )
359        .map_err(|e| SedimentError::Database(format!("Failed to create migrated batch: {}", e)))
360    }
361
362    /// Ensure vector indexes exist on tables with enough rows.
363    ///
364    /// LanceDB requires at least 256 rows before creating an index.
365    /// Once created, the index converts brute-force scans to HNSW/IVF-PQ.
366    async fn ensure_vector_index(&self) -> Result<()> {
367        const MIN_ROWS_FOR_INDEX: usize = 256;
368
369        for (name, table_opt) in [("items", &self.items_table), ("chunks", &self.chunks_table)] {
370            if let Some(table) = table_opt {
371                let row_count = table.count_rows(None).await.unwrap_or(0);
372                if row_count < MIN_ROWS_FOR_INDEX {
373                    continue;
374                }
375
376                // Check if index already exists by listing indices
377                let indices = table.list_indices().await.unwrap_or_default();
378
379                let has_vector_index = indices
380                    .iter()
381                    .any(|idx| idx.columns.contains(&"vector".to_string()));
382
383                if !has_vector_index {
384                    info!(
385                        "Creating vector index on {} table ({} rows)",
386                        name, row_count
387                    );
388                    match table
389                        .create_index(&["vector"], lancedb::index::Index::Auto)
390                        .execute()
391                        .await
392                    {
393                        Ok(_) => info!("Vector index created on {} table", name),
394                        Err(e) => {
395                            // Non-fatal: brute-force search still works
396                            tracing::warn!("Failed to create vector index on {}: {}", name, e);
397                        }
398                    }
399                }
400            }
401        }
402
403        Ok(())
404    }
405
406    /// Get or create the items table
407    async fn get_items_table(&mut self) -> Result<&Table> {
408        if self.items_table.is_none() {
409            let schema = Arc::new(item_schema());
410            let table = self
411                .db
412                .create_empty_table("items", schema)
413                .execute()
414                .await
415                .map_err(|e| {
416                    SedimentError::Database(format!("Failed to create items table: {}", e))
417                })?;
418            self.items_table = Some(table);
419        }
420        Ok(self.items_table.as_ref().unwrap())
421    }
422
423    /// Get or create the chunks table
424    async fn get_chunks_table(&mut self) -> Result<&Table> {
425        if self.chunks_table.is_none() {
426            let schema = Arc::new(chunk_schema());
427            let table = self
428                .db
429                .create_empty_table("chunks", schema)
430                .execute()
431                .await
432                .map_err(|e| {
433                    SedimentError::Database(format!("Failed to create chunks table: {}", e))
434                })?;
435            self.chunks_table = Some(table);
436        }
437        Ok(self.chunks_table.as_ref().unwrap())
438    }
439
440    // ==================== Item Operations ====================
441
442    /// Store an item with automatic chunking for long content
443    ///
444    /// Returns a `StoreResult` containing the new item ID and any potential conflicts
445    /// (items with similarity >= 0.85 to the new content).
446    pub async fn store_item(&mut self, mut item: Item) -> Result<StoreResult> {
447        // Set project_id if not already set and we have a current project
448        if item.project_id.is_none() {
449            item.project_id = self.project_id.clone();
450        }
451
452        // Determine if we need to chunk (by character count, not byte count,
453        // so multi-byte UTF-8 content isn't prematurely chunked)
454        let should_chunk = item.content.chars().count() > CHUNK_THRESHOLD;
455        item.is_chunked = should_chunk;
456
457        // Generate item embedding
458        let embedding_text = item.embedding_text();
459        let embedding = self.embedder.embed(&embedding_text)?;
460        item.embedding = embedding;
461
462        // Store the item
463        let table = self.get_items_table().await?;
464        let batch = item_to_batch(&item)?;
465        let batches = RecordBatchIterator::new(vec![Ok(batch)], Arc::new(item_schema()));
466
467        table
468            .add(Box::new(batches))
469            .execute()
470            .await
471            .map_err(|e| SedimentError::Database(format!("Failed to store item: {}", e)))?;
472
473        // If chunking is needed, create and store chunks
474        if should_chunk {
475            let embedder = self.embedder.clone();
476            let chunks_table = self.get_chunks_table().await?;
477
478            // Detect content type for smart chunking
479            let content_type = detect_content_type(&item.content);
480            let config = ChunkingConfig::default();
481            let mut chunk_results = chunk_content(&item.content, content_type, &config);
482
483            // Cap chunk count to prevent CPU exhaustion from pathological inputs
484            if chunk_results.len() > MAX_CHUNKS_PER_ITEM {
485                tracing::warn!(
486                    "Chunk count {} exceeds limit {}, truncating",
487                    chunk_results.len(),
488                    MAX_CHUNKS_PER_ITEM
489                );
490                chunk_results.truncate(MAX_CHUNKS_PER_ITEM);
491            }
492
493            for (i, chunk_result) in chunk_results.iter().enumerate() {
494                let mut chunk = Chunk::new(&item.id, i, &chunk_result.content);
495
496                if let Some(ctx) = &chunk_result.context {
497                    chunk = chunk.with_context(ctx);
498                }
499
500                let chunk_embedding = embedder.embed(&chunk.content)?;
501                chunk.embedding = chunk_embedding;
502
503                let chunk_batch = chunk_to_batch(&chunk)?;
504                let batches =
505                    RecordBatchIterator::new(vec![Ok(chunk_batch)], Arc::new(chunk_schema()));
506
507                chunks_table
508                    .add(Box::new(batches))
509                    .execute()
510                    .await
511                    .map_err(|e| {
512                        SedimentError::Database(format!("Failed to store chunk: {}", e))
513                    })?;
514            }
515
516            debug!(
517                "Stored item: {} with {} chunks",
518                item.id,
519                chunk_results.len()
520            );
521        } else {
522            debug!("Stored item: {} (no chunking)", item.id);
523        }
524
525        // Detect conflicts after storing (informational only, avoids TOCTOU race)
526        let potential_conflicts = self
527            .find_similar_items(
528                &item.content,
529                CONFLICT_SIMILARITY_THRESHOLD,
530                CONFLICT_SEARCH_LIMIT,
531            )
532            .await
533            .unwrap_or_default()
534            .into_iter()
535            .filter(|c| c.id != item.id)
536            .collect();
537
538        Ok(StoreResult {
539            id: item.id,
540            potential_conflicts,
541        })
542    }
543
544    /// Search items by semantic similarity
545    pub async fn search_items(
546        &mut self,
547        query: &str,
548        limit: usize,
549        filters: ItemFilters,
550    ) -> Result<Vec<SearchResult>> {
551        // Cap limit to prevent overflow in limit*2 and limit*3 multiplications below
552        let limit = limit.min(1000);
553        // Retry vector index creation if it failed previously
554        self.ensure_vector_index().await?;
555
556        // Generate query embedding
557        let query_embedding = self.embedder.embed(query)?;
558        let min_similarity = filters.min_similarity.unwrap_or(0.3);
559
560        // We need to search both items and chunks, then merge results
561        let mut results_map: std::collections::HashMap<String, (SearchResult, f32)> =
562            std::collections::HashMap::new();
563
564        // Search items table directly (for non-chunked items and chunked items)
565        if let Some(table) = &self.items_table {
566            let query_builder = table
567                .vector_search(query_embedding.clone())
568                .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
569                .limit(limit * 2);
570
571            let results = query_builder
572                .execute()
573                .await
574                .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
575                .try_collect::<Vec<_>>()
576                .await
577                .map_err(|e| {
578                    SedimentError::Database(format!("Failed to collect results: {}", e))
579                })?;
580
581            for batch in results {
582                let items = batch_to_items(&batch)?;
583                let distances = batch
584                    .column_by_name("_distance")
585                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
586
587                for (i, item) in items.into_iter().enumerate() {
588                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
589                    let similarity = 1.0 / (1.0 + distance);
590
591                    if similarity < min_similarity {
592                        continue;
593                    }
594
595                    // Apply project boosting
596                    let boosted_similarity = boost_similarity(
597                        similarity,
598                        item.project_id.as_deref(),
599                        self.project_id.as_deref(),
600                    );
601
602                    let result = SearchResult::from_item(&item, boosted_similarity);
603                    results_map
604                        .entry(item.id.clone())
605                        .or_insert((result, boosted_similarity));
606                }
607            }
608        }
609
610        // Search chunks table (for chunked items)
611        if let Some(chunks_table) = &self.chunks_table {
612            let chunk_results = chunks_table
613                .vector_search(query_embedding)
614                .map_err(|e| {
615                    SedimentError::Database(format!("Failed to build chunk search: {}", e))
616                })?
617                .limit(limit * 3)
618                .execute()
619                .await
620                .map_err(|e| SedimentError::Database(format!("Chunk search failed: {}", e)))?
621                .try_collect::<Vec<_>>()
622                .await
623                .map_err(|e| {
624                    SedimentError::Database(format!("Failed to collect chunk results: {}", e))
625                })?;
626
627            // Group chunks by item and find best chunk for each item
628            let mut chunk_matches: std::collections::HashMap<String, (String, f32)> =
629                std::collections::HashMap::new();
630
631            for batch in chunk_results {
632                let chunks = batch_to_chunks(&batch)?;
633                let distances = batch
634                    .column_by_name("_distance")
635                    .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
636
637                for (i, chunk) in chunks.into_iter().enumerate() {
638                    let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
639                    let similarity = 1.0 / (1.0 + distance);
640
641                    if similarity < min_similarity {
642                        continue;
643                    }
644
645                    // Keep track of best matching chunk per item
646                    chunk_matches
647                        .entry(chunk.item_id.clone())
648                        .and_modify(|(content, best_sim)| {
649                            if similarity > *best_sim {
650                                *content = chunk.content.clone();
651                                *best_sim = similarity;
652                            }
653                        })
654                        .or_insert((chunk.content.clone(), similarity));
655                }
656            }
657
658            // Fetch parent items for chunk matches
659            for (item_id, (excerpt, chunk_similarity)) in chunk_matches {
660                if let Some(item) = self.get_item(&item_id).await? {
661                    // Apply project boosting
662                    let boosted_similarity = boost_similarity(
663                        chunk_similarity,
664                        item.project_id.as_deref(),
665                        self.project_id.as_deref(),
666                    );
667
668                    let result =
669                        SearchResult::from_item_with_excerpt(&item, boosted_similarity, excerpt);
670
671                    // Update if this chunk-based result is better
672                    results_map
673                        .entry(item_id)
674                        .and_modify(|(existing, existing_sim)| {
675                            if boosted_similarity > *existing_sim {
676                                *existing = result.clone();
677                                *existing_sim = boosted_similarity;
678                            }
679                        })
680                        .or_insert((result, boosted_similarity));
681                }
682            }
683        }
684
685        // Convert map to sorted vec
686        let mut search_results: Vec<SearchResult> =
687            results_map.into_values().map(|(r, _)| r).collect();
688        search_results.sort_by(|a, b| {
689            b.similarity
690                .partial_cmp(&a.similarity)
691                .unwrap_or(std::cmp::Ordering::Equal)
692        });
693        search_results.truncate(limit);
694
695        Ok(search_results)
696    }
697
698    /// Find items similar to the given content (for conflict detection)
699    ///
700    /// This searches the items table directly by content embedding to find
701    /// potentially conflicting items before storing new content.
702    pub async fn find_similar_items(
703        &mut self,
704        content: &str,
705        min_similarity: f32,
706        limit: usize,
707    ) -> Result<Vec<ConflictInfo>> {
708        // Generate embedding for the content
709        let embedding = self.embedder.embed(content)?;
710
711        let table = match &self.items_table {
712            Some(t) => t,
713            None => return Ok(Vec::new()),
714        };
715
716        let results = table
717            .vector_search(embedding)
718            .map_err(|e| SedimentError::Database(format!("Failed to build search: {}", e)))?
719            .limit(limit)
720            .execute()
721            .await
722            .map_err(|e| SedimentError::Database(format!("Search failed: {}", e)))?
723            .try_collect::<Vec<_>>()
724            .await
725            .map_err(|e| SedimentError::Database(format!("Failed to collect results: {}", e)))?;
726
727        let mut conflicts = Vec::new();
728
729        for batch in results {
730            let items = batch_to_items(&batch)?;
731            let distances = batch
732                .column_by_name("_distance")
733                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());
734
735            for (i, item) in items.into_iter().enumerate() {
736                let distance = distances.map(|d| d.value(i)).unwrap_or(0.0);
737                let similarity = 1.0 / (1.0 + distance);
738
739                if similarity >= min_similarity {
740                    conflicts.push(ConflictInfo {
741                        id: item.id,
742                        content: item.content,
743                        similarity,
744                    });
745                }
746            }
747        }
748
749        // Sort by similarity descending
750        conflicts.sort_by(|a, b| {
751            b.similarity
752                .partial_cmp(&a.similarity)
753                .unwrap_or(std::cmp::Ordering::Equal)
754        });
755
756        Ok(conflicts)
757    }
758
759    /// List items with optional filters
760    pub async fn list_items(
761        &mut self,
762        _filters: ItemFilters,
763        limit: Option<usize>,
764        scope: crate::ListScope,
765    ) -> Result<Vec<Item>> {
766        let table = match &self.items_table {
767            Some(t) => t,
768            None => return Ok(Vec::new()),
769        };
770
771        let mut filter_parts = Vec::new();
772
773        // Apply scope filter
774        match scope {
775            crate::ListScope::Project => {
776                if let Some(ref pid) = self.project_id {
777                    filter_parts.push(format!("project_id = '{}'", sanitize_sql_string(pid)));
778                } else {
779                    // No project context: return empty rather than silently listing all items
780                    return Ok(Vec::new());
781                }
782            }
783            crate::ListScope::Global => {
784                filter_parts.push("project_id IS NULL".to_string());
785            }
786            crate::ListScope::All => {
787                // No additional filter
788            }
789        }
790
791        let mut query = table.query();
792
793        if !filter_parts.is_empty() {
794            let filter_str = filter_parts.join(" AND ");
795            query = query.only_if(filter_str);
796        }
797
798        if let Some(l) = limit {
799            query = query.limit(l);
800        }
801
802        let results = query
803            .execute()
804            .await
805            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
806            .try_collect::<Vec<_>>()
807            .await
808            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;
809
810        let mut items = Vec::new();
811        for batch in results {
812            items.extend(batch_to_items(&batch)?);
813        }
814
815        Ok(items)
816    }
817
818    /// Get an item by ID
819    pub async fn get_item(&self, id: &str) -> Result<Option<Item>> {
820        if !is_valid_id(id) {
821            return Ok(None);
822        }
823        let table = match &self.items_table {
824            Some(t) => t,
825            None => return Ok(None),
826        };
827
828        let results = table
829            .query()
830            .only_if(format!("id = '{}'", sanitize_sql_string(id)))
831            .limit(1)
832            .execute()
833            .await
834            .map_err(|e| SedimentError::Database(format!("Query failed: {}", e)))?
835            .try_collect::<Vec<_>>()
836            .await
837            .map_err(|e| SedimentError::Database(format!("Failed to collect: {}", e)))?;
838
839        for batch in results {
840            let items = batch_to_items(&batch)?;
841            if let Some(item) = items.into_iter().next() {
842                return Ok(Some(item));
843            }
844        }
845
846        Ok(None)
847    }
848
849    /// Get multiple items by ID in a single query
850    pub async fn get_items_batch(&self, ids: &[&str]) -> Result<Vec<Item>> {
851        let table = match &self.items_table {
852            Some(t) => t,
853            None => return Ok(Vec::new()),
854        };
855
856        if ids.is_empty() {
857            return Ok(Vec::new());
858        }
859
860        let quoted: Vec<String> = ids
861            .iter()
862            .filter(|id| is_valid_id(id))
863            .map(|id| format!("'{}'", sanitize_sql_string(id)))
864            .collect();
865        if quoted.is_empty() {
866            return Ok(Vec::new());
867        }
868        let filter = format!("id IN ({})", quoted.join(", "));
869
870        let results = table
871            .query()
872            .only_if(filter)
873            .execute()
874            .await
875            .map_err(|e| SedimentError::Database(format!("Batch query failed: {}", e)))?
876            .try_collect::<Vec<_>>()
877            .await
878            .map_err(|e| SedimentError::Database(format!("Failed to collect batch: {}", e)))?;
879
880        let mut items = Vec::new();
881        for batch in results {
882            items.extend(batch_to_items(&batch)?);
883        }
884
885        Ok(items)
886    }
887
888    /// Delete an item and its chunks.
889    /// Returns `true` if the item existed, `false` if it was not found.
890    pub async fn delete_item(&self, id: &str) -> Result<bool> {
891        if !is_valid_id(id) {
892            return Ok(false);
893        }
894        // Check if item exists first
895        let table = match &self.items_table {
896            Some(t) => t,
897            None => return Ok(false),
898        };
899
900        let exists = self.get_item(id).await?.is_some();
901        if !exists {
902            return Ok(false);
903        }
904
905        // Delete chunks first
906        if let Some(chunks_table) = &self.chunks_table {
907            chunks_table
908                .delete(&format!("item_id = '{}'", sanitize_sql_string(id)))
909                .await
910                .map_err(|e| SedimentError::Database(format!("Delete chunks failed: {}", e)))?;
911        }
912
913        // Delete item
914        table
915            .delete(&format!("id = '{}'", sanitize_sql_string(id)))
916            .await
917            .map_err(|e| SedimentError::Database(format!("Delete failed: {}", e)))?;
918
919        Ok(true)
920    }
921
922    /// Get database statistics
923    pub async fn stats(&self) -> Result<DatabaseStats> {
924        let mut stats = DatabaseStats::default();
925
926        if let Some(table) = &self.items_table {
927            stats.item_count = table
928                .count_rows(None)
929                .await
930                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
931        }
932
933        if let Some(table) = &self.chunks_table {
934            stats.chunk_count = table
935                .count_rows(None)
936                .await
937                .map_err(|e| SedimentError::Database(format!("Count failed: {}", e)))?;
938        }
939
940        Ok(stats)
941    }
942}
943
944// ==================== Decay Scoring ====================
945
/// Compute a decay-adjusted score for a search result.
///
/// Formula: `similarity * freshness * frequency`
/// - freshness = 1.0 / (1.0 + age_days / 30.0). This is a hyperbolic decay, not a
///   constant half-life: 0.5 at 30 days, 0.25 at 90 days.
/// - frequency = 1.0 + 0.1 * ln(1 + access_count)
///
/// `now`, `created_at` and `last_accessed_at` are unix timestamps in seconds. Age is
/// measured from `last_accessed_at` when present, otherwise from `created_at`.
/// Reference times in the future count as age zero. If no access record exists, pass
/// `access_count = 0` and `last_accessed_at = None`.
///
/// Returns `0.0` when `similarity` or the final score is not finite. Extreme or
/// corrupted timestamps cannot overflow, because the age difference saturates.
pub fn score_with_decay(
    similarity: f32,
    now: i64,
    created_at: i64,
    access_count: u32,
    last_accessed_at: Option<i64>,
) -> f32 {
    // Guard against NaN/Inf from corrupted data
    if !similarity.is_finite() {
        return 0.0;
    }

    let reference_time = last_accessed_at.unwrap_or(created_at);
    // `saturating_sub`: a corrupted timestamp such as i64::MIN would otherwise overflow.
    // That panics in debug builds and yields a nonsense age in release builds.
    let age_secs = now.saturating_sub(reference_time).max(0) as f64;
    let age_days = age_secs / 86400.0;

    let freshness = 1.0 / (1.0 + age_days / 30.0);
    let frequency = 1.0 + 0.1 * (1.0 + access_count as f64).ln();

    let result = similarity * (freshness * frequency) as f32;
    if result.is_finite() { result } else { 0.0 }
}
976
977// ==================== Helper Functions ====================
978
979/// Detect content type for smart chunking
980fn detect_content_type(content: &str) -> ContentType {
981    let trimmed = content.trim();
982
983    // Check for JSON
984    if ((trimmed.starts_with('{') && trimmed.ends_with('}'))
985        || (trimmed.starts_with('[') && trimmed.ends_with(']')))
986        && serde_json::from_str::<serde_json::Value>(trimmed).is_ok()
987    {
988        return ContentType::Json;
989    }
990
991    // Check for YAML (common patterns)
992    // Require either a "---" document separator or multiple lines matching "key: value"
993    // where "key" is a simple identifier (no spaces before colon, no URLs).
994    if trimmed.contains(":\n") || trimmed.contains(": ") || trimmed.starts_with("---") {
995        let lines: Vec<&str> = trimmed.lines().take(10).collect();
996        let yaml_key_count = lines
997            .iter()
998            .filter(|line| {
999                let l = line.trim();
1000                // A YAML key line: starts with a word-like key followed by ': '
1001                // Excludes URLs (://), empty lines, comments, prose (key must be identifier-like)
1002                !l.is_empty()
1003                    && !l.starts_with('#')
1004                    && !l.contains("://")
1005                    && l.contains(": ")
1006                    && l.split(": ").next().is_some_and(|key| {
1007                        let k = key.trim_start_matches("- ");
1008                        !k.is_empty()
1009                            && k.chars()
1010                                .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
1011                    })
1012            })
1013            .count();
1014        // Require at least 2 YAML-like key lines or starts with ---
1015        if yaml_key_count >= 2 || (trimmed.starts_with("---") && yaml_key_count >= 1) {
1016            return ContentType::Yaml;
1017        }
1018    }
1019
1020    // Check for Markdown (has headers)
1021    if trimmed.contains("\n# ") || trimmed.starts_with("# ") || trimmed.contains("\n## ") {
1022        return ContentType::Markdown;
1023    }
1024
1025    // Check for code (common patterns at start of lines to avoid false positives
1026    // from English prose like "let me explain" or "import regulations")
1027    let code_patterns = [
1028        "fn ",
1029        "pub fn ",
1030        "def ",
1031        "class ",
1032        "function ",
1033        "const ",
1034        "let ",
1035        "var ",
1036        "import ",
1037        "export ",
1038        "struct ",
1039        "impl ",
1040        "trait ",
1041    ];
1042    let has_code_pattern = trimmed.lines().any(|line| {
1043        let l = line.trim();
1044        code_patterns.iter().any(|p| l.starts_with(p))
1045    });
1046    if has_code_pattern {
1047        return ContentType::Code;
1048    }
1049
1050    ContentType::Text
1051}
1052
1053// ==================== Arrow Conversion Helpers ====================
1054
1055fn item_to_batch(item: &Item) -> Result<RecordBatch> {
1056    let schema = Arc::new(item_schema());
1057
1058    let id = StringArray::from(vec![item.id.as_str()]);
1059    let content = StringArray::from(vec![item.content.as_str()]);
1060    let project_id = StringArray::from(vec![item.project_id.as_deref()]);
1061    let is_chunked = BooleanArray::from(vec![item.is_chunked]);
1062    let created_at = Int64Array::from(vec![item.created_at.timestamp()]);
1063
1064    let vector = create_embedding_array(&item.embedding)?;
1065
1066    RecordBatch::try_new(
1067        schema,
1068        vec![
1069            Arc::new(id),
1070            Arc::new(content),
1071            Arc::new(project_id),
1072            Arc::new(is_chunked),
1073            Arc::new(created_at),
1074            Arc::new(vector),
1075        ],
1076    )
1077    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
1078}
1079
1080fn batch_to_items(batch: &RecordBatch) -> Result<Vec<Item>> {
1081    let mut items = Vec::new();
1082
1083    let id_col = batch
1084        .column_by_name("id")
1085        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1086        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;
1087
1088    let content_col = batch
1089        .column_by_name("content")
1090        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1091        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;
1092
1093    let project_id_col = batch
1094        .column_by_name("project_id")
1095        .and_then(|c| c.as_any().downcast_ref::<StringArray>());
1096
1097    let is_chunked_col = batch
1098        .column_by_name("is_chunked")
1099        .and_then(|c| c.as_any().downcast_ref::<BooleanArray>());
1100
1101    let created_at_col = batch
1102        .column_by_name("created_at")
1103        .and_then(|c| c.as_any().downcast_ref::<Int64Array>());
1104
1105    let vector_col = batch
1106        .column_by_name("vector")
1107        .and_then(|c| c.as_any().downcast_ref::<FixedSizeListArray>());
1108
1109    for i in 0..batch.num_rows() {
1110        let id = id_col.value(i).to_string();
1111        let content = content_col.value(i).to_string();
1112
1113        let project_id = project_id_col.and_then(|c| {
1114            if c.is_null(i) {
1115                None
1116            } else {
1117                Some(c.value(i).to_string())
1118            }
1119        });
1120
1121        let is_chunked = is_chunked_col.map(|c| c.value(i)).unwrap_or(false);
1122
1123        let created_at = created_at_col
1124            .map(|c| {
1125                Utc.timestamp_opt(c.value(i), 0)
1126                    .single()
1127                    .unwrap_or_else(Utc::now)
1128            })
1129            .unwrap_or_else(Utc::now);
1130
1131        let embedding = vector_col
1132            .and_then(|col| {
1133                let value = col.value(i);
1134                value
1135                    .as_any()
1136                    .downcast_ref::<Float32Array>()
1137                    .map(|arr| arr.values().to_vec())
1138            })
1139            .unwrap_or_default();
1140
1141        let item = Item {
1142            id,
1143            content,
1144            embedding,
1145            project_id,
1146            is_chunked,
1147            created_at,
1148        };
1149
1150        items.push(item);
1151    }
1152
1153    Ok(items)
1154}
1155
1156fn chunk_to_batch(chunk: &Chunk) -> Result<RecordBatch> {
1157    let schema = Arc::new(chunk_schema());
1158
1159    let id = StringArray::from(vec![chunk.id.as_str()]);
1160    let item_id = StringArray::from(vec![chunk.item_id.as_str()]);
1161    let chunk_index = Int32Array::from(vec![i32::try_from(chunk.chunk_index).unwrap_or(i32::MAX)]);
1162    let content = StringArray::from(vec![chunk.content.as_str()]);
1163    let context = StringArray::from(vec![chunk.context.as_deref()]);
1164
1165    let vector = create_embedding_array(&chunk.embedding)?;
1166
1167    RecordBatch::try_new(
1168        schema,
1169        vec![
1170            Arc::new(id),
1171            Arc::new(item_id),
1172            Arc::new(chunk_index),
1173            Arc::new(content),
1174            Arc::new(context),
1175            Arc::new(vector),
1176        ],
1177    )
1178    .map_err(|e| SedimentError::Database(format!("Failed to create batch: {}", e)))
1179}
1180
1181fn batch_to_chunks(batch: &RecordBatch) -> Result<Vec<Chunk>> {
1182    let mut chunks = Vec::new();
1183
1184    let id_col = batch
1185        .column_by_name("id")
1186        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1187        .ok_or_else(|| SedimentError::Database("Missing id column".to_string()))?;
1188
1189    let item_id_col = batch
1190        .column_by_name("item_id")
1191        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1192        .ok_or_else(|| SedimentError::Database("Missing item_id column".to_string()))?;
1193
1194    let chunk_index_col = batch
1195        .column_by_name("chunk_index")
1196        .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
1197        .ok_or_else(|| SedimentError::Database("Missing chunk_index column".to_string()))?;
1198
1199    let content_col = batch
1200        .column_by_name("content")
1201        .and_then(|c| c.as_any().downcast_ref::<StringArray>())
1202        .ok_or_else(|| SedimentError::Database("Missing content column".to_string()))?;
1203
1204    let context_col = batch
1205        .column_by_name("context")
1206        .and_then(|c| c.as_any().downcast_ref::<StringArray>());
1207
1208    for i in 0..batch.num_rows() {
1209        let id = id_col.value(i).to_string();
1210        let item_id = item_id_col.value(i).to_string();
1211        let chunk_index = chunk_index_col.value(i) as usize;
1212        let content = content_col.value(i).to_string();
1213        let context = context_col.and_then(|c| {
1214            if c.is_null(i) {
1215                None
1216            } else {
1217                Some(c.value(i).to_string())
1218            }
1219        });
1220
1221        let chunk = Chunk {
1222            id,
1223            item_id,
1224            chunk_index,
1225            content,
1226            embedding: Vec::new(),
1227            context,
1228        };
1229
1230        chunks.push(chunk);
1231    }
1232
1233    Ok(chunks)
1234}
1235
1236fn create_embedding_array(embedding: &[f32]) -> Result<FixedSizeListArray> {
1237    let values = Float32Array::from(embedding.to_vec());
1238    let field = Arc::new(Field::new("item", DataType::Float32, true));
1239
1240    FixedSizeListArray::try_new(field, EMBEDDING_DIM as i32, Arc::new(values), None)
1241        .map_err(|e| SedimentError::Database(format!("Failed to create vector: {}", e)))
1242}
1243
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed "current time" shared by the decay tests (unix seconds).
    const NOW: i64 = 1_700_000_000;
    /// Seconds per day.
    const DAY: i64 = 86_400;

    #[test]
    fn test_score_with_decay_fresh_item() {
        // Age 0 gives freshness = 1.0; no accesses gives frequency = 1.0 + 0.1 * ln(1) = 1.0.
        let score = score_with_decay(0.8, NOW, NOW, 0, None);
        assert!((score - 0.8).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_30_day_old() {
        // freshness = 1 / (1 + 30/30) = 0.5, frequency = 1.0
        let score = score_with_decay(0.8, NOW, NOW - 30 * DAY, 0, None);
        assert!((score - 0.8 * 0.5).abs() < 0.001, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_frequent_access() {
        // An access right now resets freshness to 1.0 despite the 30-day-old creation time.
        // frequency = 1.0 + 0.1 * ln(11) ≈ 1.2397
        let score = score_with_decay(0.8, NOW, NOW - 30 * DAY, 10, Some(NOW));
        let frequency = 1.0 + 0.1 * 11.0_f64.ln();
        let expected = 0.8 * frequency as f32;
        assert!((score - expected).abs() < 0.01, "got {}", score);
    }

    #[test]
    fn test_score_with_decay_old_and_unused() {
        // freshness = 1 / (1 + 90/30) = 0.25
        let score = score_with_decay(0.8, NOW, NOW - 90 * DAY, 0, None);
        assert!((score - 0.8 * 0.25).abs() < 0.001, "got {}", score);
    }

    /// Check `sanitize_sql_string` against a table of (input, expected output) pairs.
    fn assert_sanitized(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(sanitize_sql_string(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sanitize_sql_string_escapes_quotes_and_backslashes() {
        assert_sanitized(&[
            ("hello", "hello"),
            ("it's", "it''s"),
            (r"a\'b", r"a\\''b"),
            (r"path\to\file", r"path\\to\\file"),
        ]);
    }

    #[test]
    fn test_sanitize_sql_string_strips_null_bytes() {
        assert_sanitized(&[
            ("abc\0def", "abcdef"),
            ("\0' OR 1=1 --", "'' OR 1=1 "),
            ("clean", "clean"),
        ]);
    }

    #[test]
    fn test_sanitize_sql_string_strips_semicolons() {
        assert_sanitized(&[
            ("a; DROP TABLE items", "a DROP TABLE items"),
            ("normal;", "normal"),
        ]);
    }

    #[test]
    fn test_sanitize_sql_string_strips_comments() {
        assert_sanitized(&[
            // `--` is removed, leaving both surrounding spaces
            ("val' -- comment", "val''  comment"),
            // `/*` is removed; the closing `*/` is left in place
            ("val' /* block */", "val''  block */"),
            // repeated line-comment markers
            ("a--b--c", "abc"),
        ]);
    }

    #[test]
    fn test_sanitize_sql_string_adversarial_inputs() {
        assert_sanitized(&[
            // classic injection payload
            ("'; DROP TABLE items;--", "'' DROP TABLE items"),
            // a zero-width space passes through untouched
            ("hello\u{200B}world", "hello\u{200B}world"),
            // empty input
            ("", ""),
            // nothing but stripped characters
            ("\0;\0", ""),
        ]);
    }

    #[test]
    fn test_is_valid_id() {
        let accepted = ["550e8400-e29b-41d4-a716-446655440000", "abcdef0123456789"];
        for id in accepted {
            assert!(is_valid_id(id), "should accept {:?}", id);
        }

        // 65 characters is one over the length limit.
        let too_long = "a".repeat(65);
        let rejected = [
            "",
            "'; DROP TABLE items;--",
            "hello world",
            "abc\0def",
            too_long.as_str(),
        ];
        for id in rejected {
            assert!(!is_valid_id(id), "should reject {:?}", id);
        }
    }

    #[test]
    fn test_detect_content_type_yaml_not_prose() {
        // Fix #11: prose that merely contains colons must not be classified as YAML.
        let prose = "Dear John:\nI wanted to write you about something.\nSubject: important matter";
        assert_ne!(
            detect_content_type(prose),
            ContentType::Yaml,
            "Prose with colons should not be detected as YAML"
        );

        // Genuine key/value YAML is still recognised.
        let yaml = "server: localhost\nport: 8080\ndatabase: mydb";
        assert_eq!(detect_content_type(yaml), ContentType::Yaml);
    }

    #[test]
    fn test_detect_content_type_yaml_with_separator() {
        assert_eq!(
            detect_content_type("---\nname: test\nversion: 1.0"),
            ContentType::Yaml
        );
    }

    #[test]
    fn test_chunk_threshold_uses_chars_not_bytes() {
        // Bug #12: CHUNK_THRESHOLD must be compared with the character count, not bytes.
        // 500 emoji are 500 chars but 2000 bytes (4 each) and must NOT trigger chunking.
        let emoji_content = "😀".repeat(500);
        assert_eq!(emoji_content.chars().count(), 500);
        assert_eq!(emoji_content.len(), 2000);
        assert!(
            emoji_content.chars().count() <= CHUNK_THRESHOLD,
            "500 chars should not exceed 1000-char threshold"
        );

        // One character over the threshold does trigger chunking.
        let long_content = "a".repeat(1001);
        assert!(
            long_content.chars().count() > CHUNK_THRESHOLD,
            "1001 chars should exceed 1000-char threshold"
        );
    }

    #[test]
    fn test_schema_version() {
        // Bound to a local so the check runs against the value, not a constant expression.
        let current = SCHEMA_VERSION;
        assert!(current >= 2, "Schema version should be at least 2");
    }
}