1use rusqlite::{params, Connection};
2
3use crate::error::Result;
4
/// A single hit returned by `AdvancedSearch::search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Document identifier, as supplied to `index`.
    pub id: String,
    /// Fused relevance score; higher is better (results are sorted descending).
    pub score: f64,
    /// Short excerpt of the document around the first query-term match
    /// (empty until `search` fills it in).
    pub snippet: String,
}
12
/// Hybrid full-text search engine backed by an in-memory SQLite database:
/// a porter-stemmed BM25 index fused with a trigram substring index.
pub struct AdvancedSearch {
    // Owns the in-memory connection holding `docs` plus both FTS5 tables.
    db: Connection,
}
19
/// DDL for the index: a canonical `docs` table, two external-content FTS5
/// indexes over it (porter-stemmed keyword search and trigram substring
/// search), and the insert/delete/update triggers that external-content
/// FTS5 tables require to stay synchronized with their source table.
const SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS docs (
    id TEXT PRIMARY KEY,
    content TEXT NOT NULL
);

CREATE VIRTUAL TABLE IF NOT EXISTS docs_porter USING fts5(
    id,
    content,
    content='docs',
    content_rowid='rowid',
    tokenize='porter ascii'
);

CREATE VIRTUAL TABLE IF NOT EXISTS docs_trigram USING fts5(
    id,
    content,
    content='docs',
    content_rowid='rowid',
    tokenize='trigram'
);

-- Keep FTS tables in sync with the docs table.
CREATE TRIGGER IF NOT EXISTS docs_ai AFTER INSERT ON docs BEGIN
    INSERT INTO docs_porter(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
    INSERT INTO docs_trigram(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
END;

CREATE TRIGGER IF NOT EXISTS docs_ad AFTER DELETE ON docs BEGIN
    INSERT INTO docs_porter(docs_porter, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_trigram(docs_trigram, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
END;

CREATE TRIGGER IF NOT EXISTS docs_au AFTER UPDATE ON docs BEGIN
    INSERT INTO docs_porter(docs_porter, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_trigram(docs_trigram, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_porter(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
    INSERT INTO docs_trigram(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
END;
"#;
70
71const RRF_K: f64 = 60.0;
76
/// Computes the Levenshtein (edit) distance between two strings, counting
/// characters rather than bytes. Uses a single-row DP with a carried
/// diagonal, so memory is O(len(b)).
fn levenshtein(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // row[j] holds the distance between the processed prefix of `source`
    // and the first j characters of `target`.
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (i, &sc) in source.iter().enumerate() {
        // Value that was row[j] before this iteration overwrote it,
        // i.e. the diagonal cell of the full DP matrix.
        let mut diagonal = row[0];
        row[0] = i + 1;
        for (j, &tc) in target.iter().enumerate() {
            let substitution = diagonal + usize::from(sc != tc);
            diagonal = row[j + 1];
            row[j + 1] = substitution.min(diagonal + 1).min(row[j] + 1);
        }
    }

    *row.last().expect("row always has at least one cell")
}
101
/// Returns a short excerpt of `content` centred on the earliest
/// (case-insensitive) occurrence of any query term, extended `window`
/// bytes to each side. Truncation at either edge is marked with an
/// ellipsis. When no term matches, the snippet starts at the beginning.
fn extract_snippet(content: &str, query_terms: &[&str], window: usize) -> String {
    let lower = content.to_lowercase();

    // Earliest match position across all terms. NOTE: positions come from
    // the lowercased copy; for the rare characters whose lowercase form
    // changes UTF-8 length the offset is approximate, which the boundary
    // clamping below tolerates (the original code shares this assumption).
    let pos = query_terms
        .iter()
        .filter_map(|term| lower.find(&term.to_lowercase()))
        .min()
        .unwrap_or(0)
        .min(content.len());

    // Snap the window edges to valid UTF-8 boundaries. This is a stable
    // replacement for `floor_char_boundary` / `ceil_char_boundary`, which
    // are unstable (nightly-only) APIs and broke the stable build.
    let mut start = pos.saturating_sub(window);
    while start > 0 && !content.is_char_boundary(start) {
        start -= 1;
    }
    // saturating_add guards against overflow for pathological `window`.
    let mut end = pos.saturating_add(window).min(content.len());
    while end < content.len() && !content.is_char_boundary(end) {
        end += 1;
    }

    let mut snippet = String::new();
    if start > 0 {
        snippet.push('…');
    }
    snippet.push_str(&content[start..end]);
    if end < content.len() {
        snippet.push('…');
    }
    snippet
}
139
140impl AdvancedSearch {
143 pub fn new() -> Result<Self> {
145 let db = Connection::open_in_memory()?;
146 db.execute_batch(SCHEMA)?;
147 Ok(Self { db })
148 }
149
150 pub fn index(&self, id: &str, content: &str) -> Result<()> {
153 self.db.execute(
154 "INSERT INTO docs (id, content) VALUES (?1, ?2)
155 ON CONFLICT(id) DO UPDATE SET content = excluded.content",
156 params![id, content],
157 )?;
158 Ok(())
159 }
160
161 pub fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
164 let query = query.trim();
165 if query.is_empty() {
166 return Ok(Vec::new());
167 }
168
169 let terms: Vec<&str> = query.split_whitespace().collect();
170
171 let bm25 = self.bm25_search(query);
173
174 let trigram = self.trigram_search(query);
176
177 let mut results = self.reciprocal_rank_fusion(&bm25, &trigram);
179
180 if results.is_empty() {
182 if let Some(corrected) = self.fuzzy_correct(query) {
183 let bm25_c = self.bm25_search(&corrected);
184 let trigram_c = self.trigram_search(&corrected);
185 results = self.reciprocal_rank_fusion(&bm25_c, &trigram_c);
186 }
187 }
188
189 if terms.len() > 1 {
191 self.proximity_rerank(&mut results, &terms);
192 }
193
194 for r in &mut results {
196 if let Ok(content) = self.get_content(&r.id) {
197 r.snippet = extract_snippet(&content, &terms, 80);
198 }
199 }
200
201 Ok(results)
202 }
203
204 fn get_content(&self, id: &str) -> Result<String> {
207 let content: String = self.db.query_row(
208 "SELECT content FROM docs WHERE id = ?1",
209 params![id],
210 |row| row.get(0),
211 )?;
212 Ok(content)
213 }
214
215 fn bm25_search(&self, query: &str) -> Vec<(f64, String)> {
218 let mut stmt = match self.db.prepare(
219 "SELECT d.id, bm25(docs_porter) AS score
220 FROM docs_porter p
221 JOIN docs d ON d.rowid = p.rowid
222 WHERE docs_porter MATCH ?1
223 ORDER BY score",
224 ) {
225 Ok(s) => s,
226 Err(_) => return Vec::new(),
227 };
228
229 let rows = match stmt.query_map(params![query], |row| {
230 Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
231 }) {
232 Ok(r) => r,
233 Err(_) => return Vec::new(),
234 };
235
236 rows.filter_map(|r| r.ok())
237 .map(|(id, score)| (score, id))
238 .collect()
239 }
240
241 fn trigram_search(&self, query: &str) -> Vec<(f64, String)> {
244 if query.len() < 3 {
246 return Vec::new();
247 }
248 let mut stmt = match self.db.prepare(
249 "SELECT d.id, bm25(docs_trigram) AS score
250 FROM docs_trigram t
251 JOIN docs d ON d.rowid = t.rowid
252 WHERE docs_trigram MATCH ?1
253 ORDER BY score",
254 ) {
255 Ok(s) => s,
256 Err(_) => return Vec::new(),
257 };
258
259 let rows = match stmt.query_map(params![query], |row| {
260 Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
261 }) {
262 Ok(r) => r,
263 Err(_) => return Vec::new(),
264 };
265
266 rows.filter_map(|r| r.ok())
267 .map(|(id, score)| (score, id))
268 .collect()
269 }
270
271 fn reciprocal_rank_fusion(
276 &self,
277 a: &[(f64, String)],
278 b: &[(f64, String)],
279 ) -> Vec<SearchResult> {
280 use std::collections::HashMap;
281
282 let mut scores: HashMap<String, f64> = HashMap::new();
283
284 for (rank, (_score, id)) in a.iter().enumerate() {
285 *scores.entry(id.clone()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
286 }
287 for (rank, (_score, id)) in b.iter().enumerate() {
288 *scores.entry(id.clone()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
289 }
290
291 let mut results: Vec<SearchResult> = scores
292 .into_iter()
293 .map(|(id, score)| SearchResult {
294 id,
295 score,
296 snippet: String::new(),
297 })
298 .collect();
299
300 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
302 results
303 }
304
305 fn fuzzy_correct(&self, query: &str) -> Option<String> {
308 let vocab = self.vocabulary();
310 if vocab.is_empty() {
311 return None;
312 }
313
314 let terms: Vec<&str> = query.split_whitespace().collect();
315 let mut corrected_terms: Vec<String> = Vec::new();
316 let mut any_corrected = false;
317
318 for term in &terms {
319 let lower = term.to_lowercase();
320 let mut best: Option<(usize, String)> = None;
321 for word in &vocab {
322 let dist = levenshtein(&lower, word);
323 if dist > 0 && dist <= 2 {
324 if best.as_ref().map_or(true, |(d, _)| dist < *d) {
325 best = Some((dist, word.clone()));
326 }
327 }
328 }
329 if let Some((_dist, correction)) = best {
330 corrected_terms.push(correction);
331 any_corrected = true;
332 } else {
333 corrected_terms.push(lower);
334 }
335 }
336
337 if any_corrected {
338 Some(corrected_terms.join(" "))
339 } else {
340 None
341 }
342 }
343
344 fn vocabulary(&self) -> Vec<String> {
346 let mut stmt = match self.db.prepare("SELECT content FROM docs") {
347 Ok(s) => s,
348 Err(_) => return Vec::new(),
349 };
350 let rows = match stmt.query_map([], |row| row.get::<_, String>(0)) {
351 Ok(r) => r,
352 Err(_) => return Vec::new(),
353 };
354
355 let mut words = std::collections::HashSet::new();
356 for row in rows.flatten() {
357 for word in row.split_whitespace() {
358 let w: String = word
359 .chars()
360 .filter(|c| c.is_alphanumeric())
361 .collect::<String>()
362 .to_lowercase();
363 if w.len() >= 2 {
364 words.insert(w);
365 }
366 }
367 }
368 words.into_iter().collect()
369 }
370
371 fn proximity_rerank(&self, results: &mut Vec<SearchResult>, query_terms: &[&str]) {
374 for r in results.iter_mut() {
375 let content = match self.get_content(&r.id) {
376 Ok(c) => c,
377 Err(_) => continue,
378 };
379 let lower = content.to_lowercase();
380
381 let mut positions: Vec<usize> = Vec::new();
383 for term in query_terms {
384 if let Some(pos) = lower.find(&term.to_lowercase()) {
385 positions.push(pos);
386 }
387 }
388
389 if positions.len() >= 2 {
390 positions.sort_unstable();
391 let span = positions.last().unwrap() - positions.first().unwrap();
393 let boost = 1.0 + 1.0 / (1.0 + span as f64 / 50.0);
396 r.score *= boost;
397 }
398 }
399
400 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory index for each test.
    fn make_search() -> AdvancedSearch {
        AdvancedSearch::new().unwrap()
    }

    #[test]
    fn test_index_and_bm25_search() {
        let s = make_search();
        s.index("d1", "the quick brown fox jumps over the lazy dog").unwrap();
        s.index("d2", "a fast red car drives on the highway").unwrap();

        let results = s.bm25_search("fox");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "d1");
    }

    #[test]
    fn test_trigram_search() {
        let s = make_search();
        s.index("d1", "authentication middleware handles tokens").unwrap();
        s.index("d2", "database migration scripts for postgres").unwrap();

        // "auth" is a substring of "authentication" only.
        let results = s.trigram_search("auth");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "d1");
    }

    #[test]
    fn test_rrf_merge_both_lists() {
        let s = make_search();
        s.index("d1", "rust programming language systems").unwrap();
        s.index("d2", "rust prevention coating for metal surfaces").unwrap();
        s.index("d3", "programming in python is fun").unwrap();

        // d1 matches both query terms and should be fused to the top.
        let results = s.search("rust programming").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_rrf_docs_in_both_rank_higher() {
        let s = make_search();
        s.index("d1", "alpha beta gamma delta").unwrap();
        s.index("d2", "alpha only here nothing else relevant").unwrap();

        let bm25 = s.bm25_search("alpha");
        let trigram = s.trigram_search("alpha");

        let merged = s.reciprocal_rank_fusion(&bm25, &trigram);
        assert!(!merged.is_empty());

        let in_bm25: std::collections::HashSet<_> =
            bm25.iter().map(|(_, id)| id.clone()).collect();
        let in_trigram: std::collections::HashSet<_> =
            trigram.iter().map(|(_, id)| id.clone()).collect();
        let in_both: std::collections::HashSet<_> =
            in_bm25.intersection(&in_trigram).cloned().collect();

        // Every document that appeared in both lists must outscore every
        // document that appeared in only one. (The original test ended in
        // an empty `if` body here and asserted nothing.)
        let min_both = merged
            .iter()
            .filter(|r| in_both.contains(&r.id))
            .map(|r| r.score)
            .fold(f64::INFINITY, f64::min);
        let max_single = merged
            .iter()
            .filter(|r| !in_both.contains(&r.id))
            .map(|r| r.score)
            .fold(f64::NEG_INFINITY, f64::max);
        if min_both.is_finite() && max_single.is_finite() {
            assert!(
                min_both > max_single,
                "docs in both lists (min {}) must outscore single-list docs (max {})",
                min_both, max_single
            );
        }
    }

    #[test]
    fn test_fuzzy_correction() {
        let s = make_search();
        s.index("d1", "authentication middleware").unwrap();
        s.index("d2", "database migration").unwrap();

        // One deletion away from "authentication".
        let corrected = s.fuzzy_correct("authentcation");
        assert!(corrected.is_some());
        let c = corrected.unwrap();
        assert!(c.contains("authentication"), "corrected to: {}", c);
    }

    #[test]
    fn test_fuzzy_search_end_to_end() {
        let s = make_search();
        s.index("d1", "authentication middleware handles tokens").unwrap();

        let results = s.search("authentcation").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_proximity_reranking() {
        let s = make_search();
        // d1 has the terms close together; d2 has them far apart.
        s.index("d1", "the error handler catches all exceptions").unwrap();
        s.index(
            "d2",
            "an error occurred in the system and after many lines of unrelated text the handler was invoked",
        ).unwrap();

        let results = s.search("error handler").unwrap();
        assert!(results.len() >= 2);
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_smart_snippet_extraction() {
        let content = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. \
                       The authentication module verifies JWT tokens. \
                       Sed do eiusmod tempor incididunt ut labore.";
        let snippet = extract_snippet(content, &["authentication"], 40);
        assert!(snippet.contains("authentication"));
        // Match sits mid-document, so both edges are truncated.
        assert!(snippet.contains("…"));
    }

    #[test]
    fn test_empty_query_returns_empty() {
        let s = make_search();
        s.index("d1", "some content").unwrap();
        let results = s.search("").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_no_results_returns_empty() {
        let s = make_search();
        s.index("d1", "hello world").unwrap();
        let results = s.search("zzzznonexistent").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_index_upsert() {
        let s = make_search();
        s.index("d1", "original content about cats").unwrap();
        s.index("d1", "updated content about dogs").unwrap();

        let results = s.search("dogs").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "d1");

        // The old content must no longer be findable after the upsert.
        let results = s.search("cats").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
        assert_eq!(levenshtein("hello", "hello"), 0);
        assert_eq!(levenshtein("hello", "helo"), 1);
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", ""), 3);
    }

    #[test]
    fn test_snippet_at_start() {
        let content = "authentication is important for security";
        let snippet = extract_snippet(content, &["authentication"], 80);
        assert!(snippet.contains("authentication"));
        // Match is at position 0, so no leading ellipsis.
        assert!(!snippet.starts_with('…'));
    }

    #[test]
    fn test_multiple_documents_search() {
        let s = make_search();
        for i in 0..10 {
            s.index(&format!("d{}", i), &format!("document number {} about testing", i))
                .unwrap();
        }
        let results = s.search("testing").unwrap();
        assert_eq!(results.len(), 10);
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::{HashMap, HashSet};

        proptest! {
            // RRF invariants over synthetic ranked lists: the merge keeps
            // every unique id, and ids present in both lists score above
            // ids present in only one.
            #[test]
            fn prop_rrf_merge_contains_all_unique_docs_and_both_rank_higher(
                shared_count in 1..4usize,
                a_only_count in 1..4usize,
                b_only_count in 1..4usize,
            ) {
                let s = make_search();

                let mut list_a: Vec<(f64, String)> = Vec::new();
                let mut list_b: Vec<(f64, String)> = Vec::new();

                for i in 0..shared_count {
                    let id = format!("shared_{}", i);
                    list_a.push((-(i as f64), id.clone()));
                    list_b.push((-(i as f64), id));
                }

                for i in 0..a_only_count {
                    let id = format!("a_only_{}", i);
                    list_a.push((-((shared_count + i) as f64), id));
                }

                for i in 0..b_only_count {
                    let id = format!("b_only_{}", i);
                    list_b.push((-((shared_count + i) as f64), id));
                }

                let merged = s.reciprocal_rank_fusion(&list_a, &list_b);

                let all_ids: HashSet<String> = list_a.iter().map(|(_, id)| id.clone())
                    .chain(list_b.iter().map(|(_, id)| id.clone()))
                    .collect();
                let merged_ids: HashSet<String> = merged.iter().map(|r| r.id.clone()).collect();
                prop_assert_eq!(
                    merged_ids, all_ids,
                    "Merged result set must contain all unique documents from both input lists"
                );

                let a_ids: HashSet<String> = list_a.iter().map(|(_, id)| id.clone()).collect();
                let b_ids: HashSet<String> = list_b.iter().map(|(_, id)| id.clone()).collect();
                let in_both: HashSet<String> = a_ids.intersection(&b_ids).cloned().collect();
                let in_one_only: HashSet<String> = a_ids.symmetric_difference(&b_ids).cloned().collect();

                if !in_both.is_empty() && !in_one_only.is_empty() {
                    let scores: HashMap<String, f64> = merged.iter()
                        .map(|r| (r.id.clone(), r.score))
                        .collect();

                    let min_both_score = in_both.iter()
                        .filter_map(|id| scores.get(id))
                        .cloned()
                        .fold(f64::INFINITY, f64::min);

                    let max_one_score = in_one_only.iter()
                        .filter_map(|id| scores.get(id))
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);

                    prop_assert!(
                        min_both_score > max_one_score,
                        "Documents in both lists (min score {}) must score higher \
                         than documents in only one list (max score {})",
                        min_both_score, max_one_score
                    );
                }
            }
        }
    }
}