// rustant_core/search.rs

//! # Hybrid Search (Tantivy + sqlite-vec)
//!
//! Combines Tantivy full-text search with SQLite-based vector similarity for
//! hybrid fact retrieval. Facts are indexed in both systems and results are
//! blended using configurable weights.
//!
//! This module uses a simple TF-IDF–style embedding (bag-of-words) rather
//! than requiring an external embedding model.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::*;
use tantivy::{Index, IndexReader, IndexWriter, ReloadPolicy, doc};

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

/// A single search result combining full-text and vector scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Identifier of the matched fact (the `id` field in the Tantivy index).
    pub fact_id: String,
    /// Stored fact content. May be empty for vector-only hits (see
    /// `HybridSearchEngine::search_vector`, which does not carry content).
    pub content: String,
    /// Raw Tantivy relevance score; 0.0 when the fact matched only by vector.
    pub full_text_score: f32,
    /// Cosine similarity against the query embedding; 0.0 when the fact
    /// matched only by full-text. Non-negative here because embeddings are
    /// built from non-negative term counts.
    pub vector_score: f32,
    /// Weighted blend of the two scores above; used for final ranking.
    pub combined_score: f32,
}

/// Configuration for the hybrid search engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
    /// Directory for the Tantivy index.
    pub index_path: PathBuf,
    /// Path for the SQLite vector database.
    /// NOTE(review): not currently read by `HybridSearchEngine` in this
    /// module — vectors are held in memory only. Confirm where persistence
    /// is supposed to happen.
    pub db_path: PathBuf,
    /// Dimensionality of the embedding vectors.
    pub vector_dimensions: usize,
    /// Weight for full-text search scores in the combined score.
    /// Weights are applied as-is; they are NOT renormalized to sum to 1.
    pub full_text_weight: f32,
    /// Weight for vector similarity scores in the combined score.
    pub vector_weight: f32,
    /// Maximum number of results to return.
    pub max_results: usize,
}
48
49impl Default for SearchConfig {
50    fn default() -> Self {
51        Self {
52            index_path: PathBuf::from(".rustant/search_index"),
53            db_path: PathBuf::from(".rustant/vectors.db"),
54            vector_dimensions: 128,
55            full_text_weight: 0.5,
56            vector_weight: 0.5,
57            max_results: 10,
58        }
59    }
60}

/// Errors specific to the search subsystem.
#[derive(Debug, thiserror::Error)]
pub enum SearchError {
    /// Tantivy index creation, commit, reload, parse, or query failure.
    #[error("Index error: {0}")]
    IndexError(String),
    /// SQLite / vector-store failure.
    /// NOTE(review): never constructed in this module — presumably reserved
    /// for the persistence layer; confirm.
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// An operation was attempted before the engine was opened.
    #[error("Search engine not initialized")]
    NotInitialized,
}

// ---------------------------------------------------------------------------
// Simple TF-IDF Embedder
// ---------------------------------------------------------------------------

/// A minimal bag-of-words embedder using term frequency.
#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    // Length of every vector produced by `embed`.
    dimensions: usize,
}

impl SimpleEmbedder {
    /// Create an embedder producing vectors with `dimensions` entries.
    pub fn new(dimensions: usize) -> Self {
        Self { dimensions }
    }

    /// Generate a simple embedding from text.
    ///
    /// Lowercases the input, splits on non-alphanumeric characters, counts
    /// term frequencies, hashes each unique term to a dimension index, and
    /// L2-normalises the accumulated counts. Empty or token-free input
    /// yields an all-zero vector.
    pub fn embed(&self, text: &str) -> Vec<f32> {
        let mut embedding = vec![0.0f32; self.dimensions];

        let lowered = text.to_lowercase();

        // Term-frequency table over the lowercased alphanumeric tokens.
        let mut term_counts: HashMap<&str, usize> = HashMap::new();
        for token in lowered
            .split(|c: char| !c.is_alphanumeric())
            .filter(|t| !t.is_empty())
        {
            *term_counts.entry(token).or_insert(0) += 1;
        }

        if term_counts.is_empty() {
            return embedding;
        }

        // Fold each unique term's count into its hashed bucket. Counts are
        // small integers, so f32 accumulation is exact regardless of the
        // map's iteration order.
        for (term, count) in &term_counts {
            embedding[simple_hash(term) % self.dimensions] += *count as f32;
        }

        // Scale to unit L2 norm (guard against the all-zero vector).
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|v| *v /= norm);
        }

        embedding
    }
}

/// djb2-style byte hash used to bucket terms into dimensions.
fn simple_hash(s: &str) -> usize {
    s.bytes()
        .fold(5381usize, |h, b| h.wrapping_mul(33).wrapping_add(b as usize))
}

/// Compute cosine similarity between two vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or when either vector
/// has zero magnitude; otherwise `dot(a, b) / (|a| * |b|)`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate dot product and squared norms sequentially, matching the
    // left-to-right accumulation order of separate iterator sums.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}

// ---------------------------------------------------------------------------
// Hybrid Search Engine
// ---------------------------------------------------------------------------

/// Hybrid search engine combining Tantivy full-text and vector similarity.
pub struct HybridSearchEngine {
    /// Paths, weights, and result limits.
    config: SearchConfig,
    /// Tantivy index handle; used to construct query parsers.
    index: Index,
    /// Reader for running searches; explicitly reloaded before each query.
    reader: IndexReader,
    /// Writer for adding/deleting documents; committed per operation.
    writer: IndexWriter,
    /// Retained schema handle; field handles are cached below.
    _schema: Schema,
    /// Cached handle to the `id` field (raw/stored).
    id_field: Field,
    /// Cached handle to the `content` field (tokenized/stored).
    content_field: Field,
    /// Bag-of-words embedder used for vector scoring.
    embedder: SimpleEmbedder,
    // In-memory vector store keyed by fact id.
    // NOTE(review): `config.db_path` is never read in this module, so these
    // vectors are NOT persisted to SQLite here despite the module docs —
    // confirm where persistence is meant to live.
    vectors: HashMap<String, Vec<f32>>,
}
168
169impl std::fmt::Debug for HybridSearchEngine {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.debug_struct("HybridSearchEngine")
172            .field("config", &self.config)
173            .field("indexed_count", &self.vectors.len())
174            .finish()
175    }
176}

impl HybridSearchEngine {
    /// Create or open a hybrid search engine at the configured paths.
    ///
    /// Builds the two-field schema (`id`: raw + stored, `content`:
    /// tokenized + stored), ensures the index directory exists, then tries
    /// to create a fresh index and falls back to opening an existing one.
    ///
    /// # Errors
    ///
    /// Returns [`SearchError::IndexError`] if the directory cannot be
    /// created or the Tantivy index/reader/writer cannot be set up.
    pub fn open(config: SearchConfig) -> Result<Self, SearchError> {
        // Build schema
        let mut schema_builder = Schema::builder();
        let id_field = schema_builder.add_text_field("id", STRING | STORED);
        let content_field = schema_builder.add_text_field("content", TEXT | STORED);
        let schema = schema_builder.build();

        // Create index directory
        std::fs::create_dir_all(&config.index_path).map_err(|e| {
            SearchError::IndexError(format!("Failed to create index directory: {}", e))
        })?;

        // Create-then-open: `create_in_dir` fails when an index already
        // exists, in which case we open it instead.
        // NOTE(review): if create fails for a different reason (e.g.
        // permissions), the reported error is the *open* error, masking the
        // original create failure — confirm this is acceptable.
        let index = Index::create_in_dir(&config.index_path, schema.clone())
            .or_else(|_| Index::open_in_dir(&config.index_path))
            .map_err(|e| SearchError::IndexError(format!("Failed to open index: {}", e)))?;

        let reader = index
            .reader_builder()
            .reload_policy(ReloadPolicy::OnCommitWithDelay)
            .try_into()
            .map_err(|e| SearchError::IndexError(format!("Failed to create reader: {}", e)))?;

        let writer = index
            .writer(50_000_000) // 50MB heap
            .map_err(|e| SearchError::IndexError(format!("Failed to create writer: {}", e)))?;

        let embedder = SimpleEmbedder::new(config.vector_dimensions);

        Ok(Self {
            config,
            index,
            reader,
            writer,
            _schema: schema,
            id_field,
            content_field,
            embedder,
            vectors: HashMap::new(),
        })
    }

    /// Index a fact for both full-text and vector search.
    ///
    /// Adds the document to Tantivy and commits immediately, then stores
    /// the fact's embedding in the in-memory vector map.
    ///
    /// NOTE(review): committing per call keeps writes durable but is the
    /// expensive path — batch indexing would amortize the commit cost.
    ///
    /// # Errors
    ///
    /// Returns [`SearchError::IndexError`] if the add or commit fails.
    pub fn index_fact(&mut self, fact_id: &str, content: &str) -> Result<(), SearchError> {
        // Tantivy full-text
        self.writer
            .add_document(doc!(
                self.id_field => fact_id,
                self.content_field => content,
            ))
            .map_err(|e| SearchError::IndexError(format!("Failed to add document: {}", e)))?;

        self.writer
            .commit()
            .map_err(|e| SearchError::IndexError(format!("Failed to commit: {}", e)))?;

        // Vector embedding
        let embedding = self.embedder.embed(content);
        self.vectors.insert(fact_id.to_string(), embedding);

        Ok(())
    }

    /// Remove a fact from the index.
    ///
    /// Deletes by exact `id` term from Tantivy (committed immediately) and
    /// drops the in-memory embedding. Removing an unknown id is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`SearchError::IndexError`] if the commit fails.
    pub fn remove_fact(&mut self, fact_id: &str) -> Result<(), SearchError> {
        let term = tantivy::Term::from_field_text(self.id_field, fact_id);
        self.writer.delete_term(term);
        self.writer
            .commit()
            .map_err(|e| SearchError::IndexError(format!("Failed to commit delete: {}", e)))?;

        self.vectors.remove(fact_id);
        Ok(())
    }

    /// Full-text search only.
    ///
    /// Explicitly reloads the reader (the reload policy is delayed, so this
    /// makes just-committed documents visible), parses `query` against the
    /// `content` field, and returns up to `max_results` hits. Each result
    /// carries the Tantivy score as both `full_text_score` and
    /// `combined_score`; `vector_score` is left at 0.0.
    ///
    /// # Errors
    ///
    /// Returns [`SearchError::IndexError`] on reload, parse, search, or
    /// document-retrieval failure.
    pub fn search_text(&self, query: &str) -> Result<Vec<SearchResult>, SearchError> {
        self.reader
            .reload()
            .map_err(|e| SearchError::IndexError(format!("Failed to reload reader: {}", e)))?;

        let searcher = self.reader.searcher();
        let query_parser = QueryParser::for_index(&self.index, vec![self.content_field]);
        let parsed = query_parser
            .parse_query(query)
            .map_err(|e| SearchError::IndexError(format!("Failed to parse query: {}", e)))?;

        let top_docs = searcher
            .search(&parsed, &TopDocs::with_limit(self.config.max_results))
            .map_err(|e| SearchError::IndexError(format!("Search failed: {}", e)))?;

        let mut results = Vec::new();
        for (score, doc_address) in top_docs {
            let doc: TantivyDocument = searcher
                .doc(doc_address)
                .map_err(|e| SearchError::IndexError(format!("Failed to retrieve doc: {}", e)))?;

            // Missing stored fields fall back to "" rather than erroring.
            let id = doc
                .get_first(self.id_field)
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string();
            let content = doc
                .get_first(self.content_field)
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string();

            results.push(SearchResult {
                fact_id: id,
                content,
                full_text_score: score,
                vector_score: 0.0,
                combined_score: score,
            });
        }

        Ok(results)
    }

    /// Vector similarity search only.
    ///
    /// Embeds the query and scores every stored vector by cosine
    /// similarity (a full linear scan over the in-memory map), returning
    /// the top `max_results`. Results carry an empty `content`; the caller
    /// is expected to enrich them from its own store.
    pub fn search_vector(&self, query: &str) -> Vec<SearchResult> {
        let query_embedding = self.embedder.embed(query);

        let mut scored: Vec<(String, f32)> = self
            .vectors
            .iter()
            .map(|(id, vec)| {
                let sim = cosine_similarity(&query_embedding, vec);
                (id.clone(), sim)
            })
            .collect();

        // Descending by similarity; NaN-safe via partial_cmp fallback.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(self.config.max_results);

        scored
            .into_iter()
            .map(|(id, score)| SearchResult {
                fact_id: id,
                content: String::new(), // caller can enrich from memory
                full_text_score: 0.0,
                vector_score: score,
                combined_score: score,
            })
            .collect()
    }

    /// Hybrid search: combines full-text and vector results with weighted scoring.
    ///
    /// Runs both subsystems, merges per fact id, and ranks by
    /// `full_text_score * full_text_weight + vector_score * vector_weight`.
    /// A fact found by only one subsystem keeps 0.0 for the other score;
    /// facts found only by vector search carry an empty `content`.
    ///
    /// # Errors
    ///
    /// Propagates [`SearchError::IndexError`] from the full-text pass.
    pub fn search(&self, query: &str) -> Result<Vec<SearchResult>, SearchError> {
        let text_results = self.search_text(query)?;
        let vector_results = self.search_vector(query);

        // Merge results by fact_id
        let mut merged: HashMap<String, SearchResult> = HashMap::new();

        for r in text_results {
            merged
                .entry(r.fact_id.clone())
                .and_modify(|existing| {
                    existing.full_text_score = r.full_text_score;
                })
                .or_insert(SearchResult {
                    fact_id: r.fact_id,
                    content: r.content,
                    full_text_score: r.full_text_score,
                    vector_score: 0.0,
                    combined_score: 0.0,
                });
        }

        // Vector hits are merged second so that content stored by the
        // full-text pass is preserved on overlapping ids.
        for r in vector_results {
            merged
                .entry(r.fact_id.clone())
                .and_modify(|existing| {
                    existing.vector_score = r.vector_score;
                })
                .or_insert(SearchResult {
                    fact_id: r.fact_id,
                    content: r.content,
                    full_text_score: 0.0,
                    vector_score: r.vector_score,
                    combined_score: 0.0,
                });
        }

        // Compute combined scores
        let mut results: Vec<SearchResult> = merged
            .into_values()
            .map(|mut r| {
                r.combined_score = r.full_text_score * self.config.full_text_weight
                    + r.vector_score * self.config.vector_weight;
                r
            })
            .collect();

        results.sort_by(|a, b| {
            b.combined_score
                .partial_cmp(&a.combined_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(self.config.max_results);

        Ok(results)
    }

    /// Number of indexed facts.
    ///
    /// Counts entries in the in-memory vector map, which tracks the Tantivy
    /// index one-to-one via `index_fact`/`remove_fact`.
    pub fn indexed_count(&self) -> usize {
        self.vectors.len()
    }

    /// Get the current configuration.
    pub fn config(&self) -> &SearchConfig {
        &self.config
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a config rooted in a fresh temporary directory.
    fn temp_config() -> SearchConfig {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path().to_path_buf();
        // Deliberately leak the TempDir guard so the directory is not
        // deleted while the test is still using it.
        std::mem::forget(tmp);
        SearchConfig {
            index_path: root.join("index"),
            db_path: root.join("vectors.db"),
            vector_dimensions: 64,
            full_text_weight: 0.5,
            vector_weight: 0.5,
            max_results: 10,
        }
    }

    // -- SimpleEmbedder -----------------------------------------------------

    #[test]
    fn test_embedder_basic() {
        let v = SimpleEmbedder::new(64).embed("hello world");
        assert_eq!(v.len(), 64);

        // Embeddings are L2-normalised, so the norm should be ~1.0.
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_embedder_empty_text() {
        let v = SimpleEmbedder::new(32).embed("");
        assert_eq!(v.len(), 32);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_embedder_deterministic() {
        let embedder = SimpleEmbedder::new(64);
        assert_eq!(
            embedder.embed("rust programming language"),
            embedder.embed("rust programming language")
        );
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0f32, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x = [1.0f32, 0.0];
        let y = [0.0f32, 1.0];
        assert!(cosine_similarity(&x, &y).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let none: [f32; 0] = [];
        assert_eq!(cosine_similarity(&none, &none), 0.0);
    }

    // -- SearchConfig -------------------------------------------------------

    #[test]
    fn test_search_config_default() {
        let cfg = SearchConfig::default();
        assert_eq!(cfg.vector_dimensions, 128);
        assert_eq!(cfg.max_results, 10);
        assert!((cfg.full_text_weight - 0.5).abs() < f32::EPSILON);
        assert!((cfg.vector_weight - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_search_config_serialization() {
        let cfg = SearchConfig::default();
        let round_trip: SearchConfig =
            serde_json::from_str(&serde_json::to_string(&cfg).unwrap()).unwrap();
        assert_eq!(round_trip.vector_dimensions, cfg.vector_dimensions);
        assert_eq!(round_trip.max_results, cfg.max_results);
    }

    // -- HybridSearchEngine -------------------------------------------------

    #[test]
    fn test_engine_open() {
        let engine = HybridSearchEngine::open(temp_config()).unwrap();
        assert_eq!(engine.indexed_count(), 0);
    }

    #[test]
    fn test_engine_index_and_count() {
        let mut engine = HybridSearchEngine::open(temp_config()).unwrap();
        engine
            .index_fact("fact-1", "Rust is a systems programming language")
            .unwrap();
        engine
            .index_fact("fact-2", "Python is great for data science")
            .unwrap();
        assert_eq!(engine.indexed_count(), 2);
    }

    #[test]
    fn test_engine_full_text_search() {
        let mut engine = HybridSearchEngine::open(temp_config()).unwrap();
        for (id, text) in [
            ("f1", "The project uses Rust for systems programming"),
            ("f2", "Python handles data processing"),
            ("f3", "JavaScript runs in the browser"),
        ] {
            engine.index_fact(id, text).unwrap();
        }

        let hits = engine.search_text("Rust programming").unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].fact_id, "f1");
    }

    #[test]
    fn test_engine_vector_search() {
        let mut engine = HybridSearchEngine::open(temp_config()).unwrap();
        for (id, text) in [
            ("f1", "The project uses Rust for systems programming"),
            ("f2", "Python handles data processing scripts"),
        ] {
            engine.index_fact(id, text).unwrap();
        }

        let hits = engine.search_vector("systems programming language");
        assert!(!hits.is_empty());
        // The Rust fact should be more similar to "systems programming"
        assert!(hits[0].vector_score > 0.0);
    }

    #[test]
    fn test_engine_hybrid_search() {
        let mut engine = HybridSearchEngine::open(temp_config()).unwrap();
        for (id, text) in [
            ("f1", "Rust systems programming language"),
            ("f2", "Python data science and machine learning"),
            ("f3", "JavaScript browser frontend development"),
        ] {
            engine.index_fact(id, text).unwrap();
        }

        let hits = engine.search("Rust programming").unwrap();
        assert!(!hits.is_empty());
        // Rust fact should rank highest, with a positive blended score.
        assert_eq!(hits[0].fact_id, "f1");
        assert!(hits[0].combined_score > 0.0);
    }

    #[test]
    fn test_engine_remove_fact() {
        let mut engine = HybridSearchEngine::open(temp_config()).unwrap();
        engine.index_fact("f1", "fact one content").unwrap();
        engine.index_fact("f2", "fact two content").unwrap();
        assert_eq!(engine.indexed_count(), 2);

        engine.remove_fact("f1").unwrap();
        assert_eq!(engine.indexed_count(), 1);
    }

    #[test]
    fn test_engine_empty_search() {
        let engine = HybridSearchEngine::open(temp_config()).unwrap();
        assert!(engine.search_vector("anything").is_empty());
    }

    #[test]
    fn test_search_result_serialization() {
        let original = SearchResult {
            fact_id: "f1".into(),
            content: "test".into(),
            full_text_score: 0.8,
            vector_score: 0.6,
            combined_score: 0.7,
        };
        let restored: SearchResult =
            serde_json::from_str(&serde_json::to_string(&original).unwrap()).unwrap();
        assert_eq!(restored.fact_id, "f1");
        assert!((restored.combined_score - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_search_error_display() {
        assert_eq!(
            SearchError::IndexError("test error".into()).to_string(),
            "Index error: test error"
        );
        assert_eq!(
            SearchError::NotInitialized.to_string(),
            "Search engine not initialized"
        );
    }
}
613}