1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::PathBuf;
13use tantivy::collector::TopDocs;
14use tantivy::query::QueryParser;
15use tantivy::schema::*;
16use tantivy::{Index, IndexReader, IndexWriter, ReloadPolicy, doc};
17
/// A single hit produced by the search engine.
///
/// Score population depends on the search path: `search_text` fills only
/// `full_text_score`, `search_vector` fills only `vector_score`, and the
/// hybrid `search` computes `combined_score` as a weighted sum of both.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Identifier of the matched fact (the `id` field in the index).
    pub fact_id: String,
    /// Stored document text. NOTE(review): empty for hits coming only from
    /// the vector path — `search_vector` does not carry content.
    pub content: String,
    /// Relevance score reported by the tantivy full-text searcher.
    pub full_text_score: f32,
    /// Cosine similarity between query and document embeddings.
    pub vector_score: f32,
    /// Weighted combination of the two scores (see `SearchConfig` weights).
    pub combined_score: f32,
}
31
/// Configuration for [`HybridSearchEngine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
    /// Directory holding the tantivy full-text index files.
    pub index_path: PathBuf,
    /// Path for the vector store. NOTE(review): not read anywhere in this
    /// file — vectors are kept only in memory; confirm persistence exists
    /// elsewhere or this field is dead configuration.
    pub db_path: PathBuf,
    /// Number of components in each embedding vector.
    pub vector_dimensions: usize,
    /// Weight applied to the full-text score in hybrid ranking.
    pub full_text_weight: f32,
    /// Weight applied to the vector-similarity score in hybrid ranking.
    pub vector_weight: f32,
    /// Maximum number of results returned by any search method.
    pub max_results: usize,
}
48
49impl Default for SearchConfig {
50 fn default() -> Self {
51 Self {
52 index_path: PathBuf::from(".rustant/search_index"),
53 db_path: PathBuf::from(".rustant/vectors.db"),
54 vector_dimensions: 128,
55 full_text_weight: 0.5,
56 vector_weight: 0.5,
57 max_results: 10,
58 }
59 }
60}
61
/// Errors produced by the search engine.
#[derive(Debug, thiserror::Error)]
pub enum SearchError {
    /// Any tantivy index operation failure (create/open/commit/search),
    /// carried as a formatted message string.
    #[error("Index error: {0}")]
    IndexError(String),
    /// Vector/database failure. NOTE(review): not constructed anywhere in
    /// this file — confirm it is used by other modules.
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// Operation attempted before the engine was set up.
    /// NOTE(review): also not constructed in this file.
    #[error("Search engine not initialized")]
    NotInitialized,
}
72
/// Deterministic, dependency-free text embedder using the hashing trick:
/// term frequencies are hashed into a fixed number of buckets and the
/// resulting vector is L2-normalized. Not a learned embedding — it only
/// captures lexical overlap.
#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    // Number of buckets / components in each produced vector.
    dimensions: usize,
}
82
83impl SimpleEmbedder {
84 pub fn new(dimensions: usize) -> Self {
85 Self { dimensions }
86 }
87
88 pub fn embed(&self, text: &str) -> Vec<f32> {
93 let mut vector = vec![0.0f32; self.dimensions];
94
95 let lowered = text.to_lowercase();
96 let words: Vec<&str> = lowered
97 .split(|c: char| !c.is_alphanumeric())
98 .filter(|w| !w.is_empty())
99 .collect();
100
101 if words.is_empty() {
102 return vector;
103 }
104
105 let mut tf: HashMap<&str, usize> = HashMap::new();
107 for word in &words {
108 *tf.entry(word).or_insert(0) += 1;
109 }
110
111 for (term, count) in &tf {
113 let idx = simple_hash(term) % self.dimensions;
114 vector[idx] += *count as f32;
115 }
116
117 let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
119 if norm > 0.0 {
120 for v in &mut vector {
121 *v /= norm;
122 }
123 }
124
125 vector
126 }
127}
128
/// djb2-style string hash: `hash = hash * 33 + byte`, seeded with 5381.
/// Wrapping arithmetic makes overflow well-defined; the result is used
/// modulo the embedder's bucket count, so raw magnitude is irrelevant.
fn simple_hash(s: &str) -> usize {
    s.bytes().fold(5381usize, |acc, byte| {
        acc.wrapping_mul(33).wrapping_add(byte as usize)
    })
}
136
/// Cosine similarity between two vectors.
///
/// Returns 0.0 when the slices differ in length, are empty, or either has
/// zero magnitude — the similarity is undefined in those cases, and 0.0 is
/// the neutral "no match" answer for ranking purposes.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating the dot product and both squared magnitudes.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
150
/// Hybrid search engine combining tantivy full-text search with an
/// in-memory embedding store ranked by cosine similarity.
pub struct HybridSearchEngine {
    /// Paths, dimensions, weights, and result limits.
    config: SearchConfig,
    /// The tantivy index (needed for building query parsers).
    index: Index,
    /// Reader used for searching; reloaded explicitly before each search.
    reader: IndexReader,
    /// Writer used for adding/removing documents; commits per operation.
    writer: IndexWriter,
    /// Retained so the schema outlives the fields; not otherwise used.
    _schema: Schema,
    /// Exact-match (STRING) field holding the fact id.
    id_field: Field,
    /// Tokenized (TEXT) field holding the fact content.
    content_field: Field,
    /// Produces query/document embeddings of `config.vector_dimensions`.
    embedder: SimpleEmbedder,
    /// fact_id -> embedding. NOTE(review): in-memory only — nothing here
    /// writes to `config.db_path`, so vectors are lost when the engine drops.
    vectors: HashMap<String, Vec<f32>>,
}
168
169impl std::fmt::Debug for HybridSearchEngine {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct("HybridSearchEngine")
172 .field("config", &self.config)
173 .field("indexed_count", &self.vectors.len())
174 .finish()
175 }
176}
177
impl HybridSearchEngine {
    /// Opens (or creates) a search engine rooted at the paths in `config`.
    ///
    /// Builds a two-field schema (`id`: exact-match + stored, `content`:
    /// tokenized + stored), ensures the index directory exists, and falls
    /// back to opening an existing index when creation fails (i.e. the
    /// index was created by a previous run).
    ///
    /// # Errors
    /// Returns [`SearchError::IndexError`] if the directory, index, reader,
    /// or writer cannot be created.
    ///
    /// NOTE(review): `config.db_path` is never touched here — the vector
    /// map starts empty on every open; confirm that is intended.
    pub fn open(config: SearchConfig) -> Result<Self, SearchError> {
        let mut schema_builder = Schema::builder();
        // STRING = raw (untokenized) indexing, so delete-by-term on the
        // exact id works; STORED lets search results return the values.
        let id_field = schema_builder.add_text_field("id", STRING | STORED);
        let content_field = schema_builder.add_text_field("content", TEXT | STORED);
        let schema = schema_builder.build();

        std::fs::create_dir_all(&config.index_path).map_err(|e| {
            SearchError::IndexError(format!("Failed to create index directory: {}", e))
        })?;

        // Try create first; on "already exists" (or any create failure)
        // fall back to opening the existing index in place.
        let index = Index::create_in_dir(&config.index_path, schema.clone())
            .or_else(|_| Index::open_in_dir(&config.index_path))
            .map_err(|e| SearchError::IndexError(format!("Failed to open index: {}", e)))?;

        // OnCommitWithDelay refreshes asynchronously after commits; reads
        // still call `reload()` explicitly (see `search_text`) to avoid
        // racing a just-committed write.
        let reader = index
            .reader_builder()
            .reload_policy(ReloadPolicy::OnCommitWithDelay)
            .try_into()
            .map_err(|e| SearchError::IndexError(format!("Failed to create reader: {}", e)))?;

        // 50 MB writer heap budget.
        let writer = index
            .writer(50_000_000) .map_err(|e| SearchError::IndexError(format!("Failed to create writer: {}", e)))?;

        let embedder = SimpleEmbedder::new(config.vector_dimensions);

        Ok(Self {
            config,
            index,
            reader,
            writer,
            _schema: schema,
            id_field,
            content_field,
            embedder,
            vectors: HashMap::new(),
        })
    }

    /// Indexes one fact: adds it to the full-text index and stores its
    /// embedding in the in-memory vector map.
    ///
    /// Commits immediately, so the document is durable (and searchable
    /// after a reader reload) when this returns. NOTE(review): one commit
    /// per document is expensive for bulk loads — consider batching if
    /// callers index many facts at once.
    ///
    /// # Errors
    /// Returns [`SearchError::IndexError`] if the add or commit fails.
    pub fn index_fact(&mut self, fact_id: &str, content: &str) -> Result<(), SearchError> {
        self.writer
            .add_document(doc!(
                self.id_field => fact_id,
                self.content_field => content,
            ))
            .map_err(|e| SearchError::IndexError(format!("Failed to add document: {}", e)))?;

        self.writer
            .commit()
            .map_err(|e| SearchError::IndexError(format!("Failed to commit: {}", e)))?;

        let embedding = self.embedder.embed(content);
        self.vectors.insert(fact_id.to_string(), embedding);

        Ok(())
    }

    /// Removes a fact from both the full-text index (delete-by-term on the
    /// exact id) and the vector map. Removing an unknown id is a no-op.
    ///
    /// # Errors
    /// Returns [`SearchError::IndexError`] if the delete commit fails.
    pub fn remove_fact(&mut self, fact_id: &str) -> Result<(), SearchError> {
        let term = tantivy::Term::from_field_text(self.id_field, fact_id);
        self.writer.delete_term(term);
        self.writer
            .commit()
            .map_err(|e| SearchError::IndexError(format!("Failed to commit delete: {}", e)))?;

        self.vectors.remove(fact_id);
        Ok(())
    }

    /// Full-text search over the `content` field.
    ///
    /// Reloads the reader first so results reflect the latest commit, then
    /// returns up to `config.max_results` hits with only `full_text_score`
    /// populated (`combined_score` mirrors it; `vector_score` is 0).
    ///
    /// # Errors
    /// Returns [`SearchError::IndexError`] on reload, query-parse, search,
    /// or document-retrieval failure — note that user-typed queries with
    /// unbalanced syntax can fail to parse.
    pub fn search_text(&self, query: &str) -> Result<Vec<SearchResult>, SearchError> {
        // Explicit reload: the OnCommitWithDelay policy refreshes lazily,
        // so without this a search right after index_fact could miss docs.
        self.reader
            .reload()
            .map_err(|e| SearchError::IndexError(format!("Failed to reload reader: {}", e)))?;

        let searcher = self.reader.searcher();
        let query_parser = QueryParser::for_index(&self.index, vec![self.content_field]);
        let parsed = query_parser
            .parse_query(query)
            .map_err(|e| SearchError::IndexError(format!("Failed to parse query: {}", e)))?;

        let top_docs = searcher
            .search(&parsed, &TopDocs::with_limit(self.config.max_results))
            .map_err(|e| SearchError::IndexError(format!("Search failed: {}", e)))?;

        let mut results = Vec::new();
        for (score, doc_address) in top_docs {
            let doc: TantivyDocument = searcher
                .doc(doc_address)
                .map_err(|e| SearchError::IndexError(format!("Failed to retrieve doc: {}", e)))?;

            // Missing stored fields degrade to empty strings rather than
            // failing the whole search.
            let id = doc
                .get_first(self.id_field)
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string();
            let content = doc
                .get_first(self.content_field)
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string();

            results.push(SearchResult {
                fact_id: id,
                content,
                full_text_score: score,
                vector_score: 0.0,
                combined_score: score,
            });
        }

        Ok(results)
    }

    /// Vector search: embeds the query and ranks all stored embeddings by
    /// cosine similarity (linear scan — fine for small fact counts).
    ///
    /// Returns up to `config.max_results` hits with only `vector_score`
    /// populated. NOTE(review): `content` is left empty here; hybrid
    /// results that match only on vectors therefore carry no content.
    pub fn search_vector(&self, query: &str) -> Vec<SearchResult> {
        let query_embedding = self.embedder.embed(query);

        let mut scored: Vec<(String, f32)> = self
            .vectors
            .iter()
            .map(|(id, vec)| {
                let sim = cosine_similarity(&query_embedding, vec);
                (id.clone(), sim)
            })
            .collect();

        // Descending by similarity; NaN-safe via partial_cmp fallback.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(self.config.max_results);

        scored
            .into_iter()
            .map(|(id, score)| SearchResult {
                fact_id: id,
                content: String::new(), full_text_score: 0.0,
                vector_score: score,
                combined_score: score,
            })
            .collect()
    }

    /// Hybrid search: runs both paths, merges per-fact scores, and ranks by
    /// `full_text_weight * text + vector_weight * vector`.
    ///
    /// Facts found by both paths get both scores; facts found by only one
    /// keep 0.0 for the other. Returns up to `config.max_results` results.
    ///
    /// # Errors
    /// Propagates any [`SearchError`] from the full-text path (the vector
    /// path is infallible).
    pub fn search(&self, query: &str) -> Result<Vec<SearchResult>, SearchError> {
        let text_results = self.search_text(query)?;
        let vector_results = self.search_vector(query);

        let mut merged: HashMap<String, SearchResult> = HashMap::new();

        for r in text_results {
            merged
                .entry(r.fact_id.clone())
                .and_modify(|existing| {
                    existing.full_text_score = r.full_text_score;
                })
                .or_insert(SearchResult {
                    fact_id: r.fact_id,
                    content: r.content,
                    full_text_score: r.full_text_score,
                    vector_score: 0.0,
                    combined_score: 0.0,
                });
        }

        for r in vector_results {
            merged
                .entry(r.fact_id.clone())
                .and_modify(|existing| {
                    // Text entry already present: just attach the vector
                    // score (keeps the stored content from the text hit).
                    existing.vector_score = r.vector_score;
                })
                .or_insert(SearchResult {
                    fact_id: r.fact_id,
                    // Vector-only hit: content is empty (see search_vector).
                    content: r.content,
                    full_text_score: 0.0,
                    vector_score: r.vector_score,
                    combined_score: 0.0,
                });
        }

        // Combined score is computed once, after both maps are merged.
        let mut results: Vec<SearchResult> = merged
            .into_values()
            .map(|mut r| {
                r.combined_score = r.full_text_score * self.config.full_text_weight
                    + r.vector_score * self.config.vector_weight;
                r
            })
            .collect();

        results.sort_by(|a, b| {
            b.combined_score
                .partial_cmp(&a.combined_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(self.config.max_results);

        Ok(results)
    }

    /// Number of facts with a stored embedding (the in-memory vector map).
    /// NOTE(review): may differ from the full-text document count if the
    /// index was populated by a previous process.
    pub fn indexed_count(&self) -> usize {
        self.vectors.len()
    }

    /// Read-only access to the engine's configuration.
    pub fn config(&self) -> &SearchConfig {
        &self.config
    }
}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SearchConfig` rooted in a fresh temporary directory.
    ///
    /// `TempDir::into_path` disables automatic cleanup on drop and hands
    /// back the `PathBuf`, which is exactly what the previous
    /// `std::mem::forget(dir)` hack achieved — but via the API intended
    /// for it. (Newer `tempfile` releases rename this to `TempDir::keep`.)
    fn temp_config() -> SearchConfig {
        let base = tempfile::tempdir().unwrap().into_path();
        SearchConfig {
            index_path: base.join("index"),
            db_path: base.join("vectors.db"),
            vector_dimensions: 64,
            full_text_weight: 0.5,
            vector_weight: 0.5,
            max_results: 10,
        }
    }

    // ---- SimpleEmbedder ----

    #[test]
    fn test_embedder_basic() {
        let embedder = SimpleEmbedder::new(64);
        let vec = embedder.embed("hello world");
        assert_eq!(vec.len(), 64);

        // Embeddings must be L2-normalized.
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_embedder_empty_text() {
        let embedder = SimpleEmbedder::new(32);
        let vec = embedder.embed("");
        assert_eq!(vec.len(), 32);
        assert!(vec.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_embedder_deterministic() {
        let embedder = SimpleEmbedder::new(64);
        let v1 = embedder.embed("rust programming language");
        let v2 = embedder.embed("rust programming language");
        assert_eq!(v1, v2);
    }

    // ---- cosine_similarity ----

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let a: Vec<f32> = vec![];
        let sim = cosine_similarity(&a, &a);
        assert_eq!(sim, 0.0);
    }

    // ---- SearchConfig ----

    #[test]
    fn test_search_config_default() {
        let config = SearchConfig::default();
        assert_eq!(config.vector_dimensions, 128);
        assert_eq!(config.max_results, 10);
        assert!((config.full_text_weight - 0.5).abs() < f32::EPSILON);
        assert!((config.vector_weight - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_search_config_serialization() {
        let config = SearchConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let restored: SearchConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.vector_dimensions, config.vector_dimensions);
        assert_eq!(restored.max_results, config.max_results);
    }

    // ---- HybridSearchEngine ----

    #[test]
    fn test_engine_open() {
        let config = temp_config();
        let engine = HybridSearchEngine::open(config).unwrap();
        assert_eq!(engine.indexed_count(), 0);
    }

    #[test]
    fn test_engine_index_and_count() {
        let config = temp_config();
        let mut engine = HybridSearchEngine::open(config).unwrap();
        engine
            .index_fact("fact-1", "Rust is a systems programming language")
            .unwrap();
        engine
            .index_fact("fact-2", "Python is great for data science")
            .unwrap();
        assert_eq!(engine.indexed_count(), 2);
    }

    #[test]
    fn test_engine_full_text_search() {
        let config = temp_config();
        let mut engine = HybridSearchEngine::open(config).unwrap();
        engine
            .index_fact("f1", "The project uses Rust for systems programming")
            .unwrap();
        engine
            .index_fact("f2", "Python handles data processing")
            .unwrap();
        engine
            .index_fact("f3", "JavaScript runs in the browser")
            .unwrap();

        let results = engine.search_text("Rust programming").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].fact_id, "f1");
    }

    #[test]
    fn test_engine_vector_search() {
        let config = temp_config();
        let mut engine = HybridSearchEngine::open(config).unwrap();
        engine
            .index_fact("f1", "The project uses Rust for systems programming")
            .unwrap();
        engine
            .index_fact("f2", "Python handles data processing scripts")
            .unwrap();

        let results = engine.search_vector("systems programming language");
        assert!(!results.is_empty());
        assert!(results[0].vector_score > 0.0);
    }

    #[test]
    fn test_engine_hybrid_search() {
        let config = temp_config();
        let mut engine = HybridSearchEngine::open(config).unwrap();
        engine
            .index_fact("f1", "Rust systems programming language")
            .unwrap();
        engine
            .index_fact("f2", "Python data science and machine learning")
            .unwrap();
        engine
            .index_fact("f3", "JavaScript browser frontend development")
            .unwrap();

        let results = engine.search("Rust programming").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].fact_id, "f1");
        assert!(results[0].combined_score > 0.0);
    }

    #[test]
    fn test_engine_remove_fact() {
        let config = temp_config();
        let mut engine = HybridSearchEngine::open(config).unwrap();
        engine.index_fact("f1", "fact one content").unwrap();
        engine.index_fact("f2", "fact two content").unwrap();
        assert_eq!(engine.indexed_count(), 2);

        engine.remove_fact("f1").unwrap();
        assert_eq!(engine.indexed_count(), 1);
    }

    #[test]
    fn test_engine_empty_search() {
        let config = temp_config();
        let engine = HybridSearchEngine::open(config).unwrap();
        let results = engine.search_vector("anything");
        assert!(results.is_empty());
    }

    // ---- SearchResult / SearchError ----

    #[test]
    fn test_search_result_serialization() {
        let result = SearchResult {
            fact_id: "f1".into(),
            content: "test".into(),
            full_text_score: 0.8,
            vector_score: 0.6,
            combined_score: 0.7,
        };
        let json = serde_json::to_string(&result).unwrap();
        let restored: SearchResult = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.fact_id, "f1");
        assert!((restored.combined_score - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_search_error_display() {
        let err = SearchError::IndexError("test error".into());
        assert_eq!(err.to_string(), "Index error: test error");

        let err = SearchError::NotInitialized;
        assert_eq!(err.to_string(), "Search engine not initialized");
    }
}