Skip to main content

sqz_engine/
advanced_search.rs

1use rusqlite::{params, Connection};
2
3use crate::error::Result;
4
/// A single search result with id, fused score, and a smart snippet.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier of the matching document (the `id` column of the `docs`
    /// table, i.e. the value passed to `AdvancedSearch::index`).
    pub id: String,
    /// Relevance score; higher is better.  This is the Reciprocal Rank
    /// Fusion sum, multiplied by the proximity boost for multi-term queries.
    pub score: f64,
    /// Excerpt of the document content produced by `extract_snippet`.
    /// Left empty when the document content could not be loaded.
    pub snippet: String,
}
12
/// Advanced search engine combining BM25 (porter stemming) and trigram
/// substring search, merged via Reciprocal Rank Fusion.  Backed by an
/// in-memory SQLite database with two FTS5 virtual tables.
pub struct AdvancedSearch {
    /// SQLite connection holding the `docs` table, both FTS5 indexes and the
    /// triggers that keep them in sync (all created from `SCHEMA`).
    db: Connection,
}
19
20// ── Schema ────────────────────────────────────────────────────────────────────
21
/// SQL executed once by `AdvancedSearch::new` to set up the database.
///
/// * `docs` is the source-of-truth table (`id` → `content`).
/// * `docs_porter` and `docs_trigram` are FTS5 tables declared with
///   `content='docs'`, so they read column values from `docs` by rowid.
///   They differ only in tokenizer: `porter ascii` versus `trigram`.
/// * The `docs_ai` / `docs_ad` / `docs_au` triggers mirror every insert,
///   delete and update on `docs` into both FTS5 tables.  Removal uses the
///   FTS5 `'delete'` command with the old column values; an update is a
///   delete of the old row followed by an insert of the new one.
///
/// Every statement uses `IF NOT EXISTS`, so running the batch again on the
/// same database does not fail on already-existing objects.
const SCHEMA: &str = r#"
CREATE TABLE IF NOT EXISTS docs (
    id      TEXT PRIMARY KEY,
    content TEXT NOT NULL
);

CREATE VIRTUAL TABLE IF NOT EXISTS docs_porter USING fts5(
    id,
    content,
    content='docs',
    content_rowid='rowid',
    tokenize='porter ascii'
);

CREATE VIRTUAL TABLE IF NOT EXISTS docs_trigram USING fts5(
    id,
    content,
    content='docs',
    content_rowid='rowid',
    tokenize='trigram'
);

-- Keep FTS tables in sync with the docs table.
CREATE TRIGGER IF NOT EXISTS docs_ai AFTER INSERT ON docs BEGIN
    INSERT INTO docs_porter(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
    INSERT INTO docs_trigram(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
END;

CREATE TRIGGER IF NOT EXISTS docs_ad AFTER DELETE ON docs BEGIN
    INSERT INTO docs_porter(docs_porter, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_trigram(docs_trigram, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
END;

CREATE TRIGGER IF NOT EXISTS docs_au AFTER UPDATE ON docs BEGIN
    INSERT INTO docs_porter(docs_porter, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_trigram(docs_trigram, rowid, id, content)
    VALUES ('delete', old.rowid, old.id, old.content);
    INSERT INTO docs_porter(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
    INSERT INTO docs_trigram(rowid, id, content)
    VALUES (new.rowid, new.id, new.content);
END;
"#;
70
71// ── RRF constant ──────────────────────────────────────────────────────────────
72
/// Reciprocal Rank Fusion smoothing constant (standard value from the
/// original Cormack, Clarke & Buettcher paper).
///
/// Each list contributes `1 / (RRF_K + rank)` to a document's fused score,
/// where `rank` is the document's 1-based position in that list.
const RRF_K: f64 = 60.0;
76
77// ── Helpers ───────────────────────────────────────────────────────────────────
78
/// Levenshtein edit distance between `a` and `b`, measured in Unicode scalar
/// values (`char`s), not bytes.
///
/// Uses the classic two-row dynamic programme: only the previous and the
/// current row of the distance matrix are kept in memory.
fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    let width = target.len();

    // `above` is the previous DP row.  It starts as the cost of turning the
    // empty prefix of `a` into each prefix of `b` (pure insertions).
    let mut above: Vec<usize> = (0..=width).collect();
    let mut row: Vec<usize> = vec![0; width + 1];

    for (i, ca) in a.chars().enumerate() {
        // Turning the first `i + 1` chars of `a` into "" is pure deletion.
        row[0] = i + 1;
        for (j, cb) in target.iter().enumerate() {
            let substitute = above[j] + usize::from(ca != *cb);
            let delete = above[j + 1] + 1;
            let insert = row[j] + 1;
            row[j + 1] = substitute.min(delete).min(insert);
        }
        std::mem::swap(&mut above, &mut row);
    }

    // After the final swap the most recently computed row lives in `above`.
    above[width]
}
101
/// Extract a smart snippet: a window of text around the first occurrence
/// of any query term, with `…` ellipsis markers when truncated.
///
/// * `content` – the full document text.
/// * `query_terms` – terms to look for; matching is case-insensitive.
/// * `window` – number of bytes of context taken on each side of the start
///   of the match (the window edges are widened to the nearest `char`
///   boundary so the slice is always valid UTF-8).
///
/// When no term occurs in `content` the snippet starts at the beginning of
/// the document.  A leading `…` is emitted when text before the window was
/// cut off, a trailing `…` when text after it was cut off.
fn extract_snippet(content: &str, query_terms: &[&str], window: usize) -> String {
    /// Translate a byte offset inside `content.to_lowercase()` into the byte
    /// offset of the originating character inside `content`.
    ///
    /// Lowercasing can change the encoded length of a character (for example
    /// `İ` is 2 bytes but lowercases to 3 bytes), so offsets in the two
    /// strings are not interchangeable.
    fn source_offset(content: &str, lowered_pos: usize) -> usize {
        let mut consumed = 0usize;
        for (idx, ch) in content.char_indices() {
            let width: usize = ch.to_lowercase().map(char::len_utf8).sum();
            if consumed + width > lowered_pos {
                return idx;
            }
            consumed += width;
        }
        content.len()
    }

    let lower = content.to_lowercase();

    // Earliest match position (in the lowercased text) across all terms.
    let earliest = query_terms
        .iter()
        .filter_map(|term| lower.find(&term.to_lowercase()))
        .min();

    // Convert to an offset into `content`.  ASCII text lowercases
    // byte-for-byte, so the walk is skipped in the common case.
    let pos = match earliest {
        Some(p) if content.is_ascii() => p,
        Some(p) => source_offset(content, p),
        None => 0,
    };

    // Saturating arithmetic keeps huge `window` values from overflowing.
    let mut start = pos.saturating_sub(window);
    let mut end = pos.saturating_add(window).min(content.len());

    // Snap to char boundaries (stable alternative to floor/ceil_char_boundary).
    while start > 0 && !content.is_char_boundary(start) {
        start -= 1;
    }
    while end < content.len() && !content.is_char_boundary(end) {
        end += 1;
    }

    let mut snippet = String::with_capacity(end - start + 2 * '…'.len_utf8());
    if start > 0 {
        snippet.push('…');
    }
    snippet.push_str(&content[start..end]);
    if end < content.len() {
        snippet.push('…');
    }
    snippet
}
147
148// ── AdvancedSearch impl ───────────────────────────────────────────────────────
149
150impl AdvancedSearch {
151    /// Create a new `AdvancedSearch` backed by an in-memory SQLite database.
152    pub fn new() -> Result<Self> {
153        let db = Connection::open_in_memory()?;
154        db.execute_batch(SCHEMA)?;
155        Ok(Self { db })
156    }
157
158    /// Index a document.  If a document with the same `id` already exists it
159    /// is replaced.
160    pub fn index(&self, id: &str, content: &str) -> Result<()> {
161        self.db.execute(
162            "INSERT INTO docs (id, content) VALUES (?1, ?2)
163             ON CONFLICT(id) DO UPDATE SET content = excluded.content",
164            params![id, content],
165        )?;
166        Ok(())
167    }
168
169    /// Run an advanced search combining BM25, trigram, RRF, fuzzy correction,
170    /// proximity reranking, and smart snippet extraction.
171    pub fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
172        let query = query.trim();
173        if query.is_empty() {
174            return Ok(Vec::new());
175        }
176
177        let terms: Vec<&str> = query.split_whitespace().collect();
178
179        // 1. BM25 search (porter stemming).
180        let bm25 = self.bm25_search(query);
181
182        // 2. Trigram substring search.
183        let trigram = self.trigram_search(query);
184
185        // 3. Merge via Reciprocal Rank Fusion.
186        let mut results = self.reciprocal_rank_fusion(&bm25, &trigram);
187
188        // 4. If no results, try fuzzy correction (Levenshtein ≤ 2).
189        if results.is_empty() {
190            if let Some(corrected) = self.fuzzy_correct(query) {
191                let bm25_c = self.bm25_search(&corrected);
192                let trigram_c = self.trigram_search(&corrected);
193                results = self.reciprocal_rank_fusion(&bm25_c, &trigram_c);
194            }
195        }
196
197        // 5. Proximity reranking for multi-term queries.
198        if terms.len() > 1 {
199            self.proximity_rerank(&mut results, &terms);
200        }
201
202        // 6. Smart snippet extraction.
203        for r in &mut results {
204            if let Ok(content) = self.get_content(&r.id) {
205                r.snippet = extract_snippet(&content, &terms, 80);
206            }
207        }
208
209        Ok(results)
210    }
211
212    // ── Internal helpers ──────────────────────────────────────────────────────
213
214    fn get_content(&self, id: &str) -> Result<String> {
215        let content: String = self.db.query_row(
216            "SELECT content FROM docs WHERE id = ?1",
217            params![id],
218            |row| row.get(0),
219        )?;
220        Ok(content)
221    }
222
223    /// BM25 search on the porter-stemmed FTS5 table.
224    /// Returns `(bm25_score, doc_id)` pairs ordered by relevance.
225    fn bm25_search(&self, query: &str) -> Vec<(f64, String)> {
226        let mut stmt = match self.db.prepare(
227            "SELECT d.id, bm25(docs_porter) AS score
228             FROM docs_porter p
229             JOIN docs d ON d.rowid = p.rowid
230             WHERE docs_porter MATCH ?1
231             ORDER BY score",
232        ) {
233            Ok(s) => s,
234            Err(_) => return Vec::new(),
235        };
236
237        let rows = match stmt.query_map(params![query], |row| {
238            Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
239        }) {
240            Ok(r) => r,
241            Err(_) => return Vec::new(),
242        };
243
244        rows.filter_map(|r| r.ok())
245            .map(|(id, score)| (score, id))
246            .collect()
247    }
248
249    /// Trigram substring search on the trigram FTS5 table.
250    /// Returns `(bm25_score, doc_id)` pairs ordered by relevance.
251    fn trigram_search(&self, query: &str) -> Vec<(f64, String)> {
252        // Trigram tokenizer requires the query to be at least 3 chars.
253        if query.len() < 3 {
254            return Vec::new();
255        }
256        let mut stmt = match self.db.prepare(
257            "SELECT d.id, bm25(docs_trigram) AS score
258             FROM docs_trigram t
259             JOIN docs d ON d.rowid = t.rowid
260             WHERE docs_trigram MATCH ?1
261             ORDER BY score",
262        ) {
263            Ok(s) => s,
264            Err(_) => return Vec::new(),
265        };
266
267        let rows = match stmt.query_map(params![query], |row| {
268            Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
269        }) {
270            Ok(r) => r,
271            Err(_) => return Vec::new(),
272        };
273
274        rows.filter_map(|r| r.ok())
275            .map(|(id, score)| (score, id))
276            .collect()
277    }
278
279    /// Merge two ranked lists using Reciprocal Rank Fusion.
280    ///
281    /// Documents appearing in both lists receive a higher fused score than
282    /// documents appearing in only one.
283    fn reciprocal_rank_fusion(
284        &self,
285        a: &[(f64, String)],
286        b: &[(f64, String)],
287    ) -> Vec<SearchResult> {
288        use std::collections::HashMap;
289
290        let mut scores: HashMap<String, f64> = HashMap::new();
291
292        for (rank, (_score, id)) in a.iter().enumerate() {
293            *scores.entry(id.clone()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
294        }
295        for (rank, (_score, id)) in b.iter().enumerate() {
296            *scores.entry(id.clone()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
297        }
298
299        let mut results: Vec<SearchResult> = scores
300            .into_iter()
301            .map(|(id, score)| SearchResult {
302                id,
303                score,
304                snippet: String::new(),
305            })
306            .collect();
307
308        // Sort descending by fused score.
309        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
310        results
311    }
312
313    /// Attempt fuzzy correction for the query.  Collects all known terms from
314    /// the docs table and finds the closest match within Levenshtein distance 2.
315    fn fuzzy_correct(&self, query: &str) -> Option<String> {
316        // Build a vocabulary of unique words from indexed documents.
317        let vocab = self.vocabulary();
318        if vocab.is_empty() {
319            return None;
320        }
321
322        let terms: Vec<&str> = query.split_whitespace().collect();
323        let mut corrected_terms: Vec<String> = Vec::new();
324        let mut any_corrected = false;
325
326        for term in &terms {
327            let lower = term.to_lowercase();
328            let mut best: Option<(usize, String)> = None;
329            for word in &vocab {
330                let dist = levenshtein(&lower, word);
331                if dist > 0 && dist <= 2 {
332                    if best.as_ref().map_or(true, |(d, _)| dist < *d) {
333                        best = Some((dist, word.clone()));
334                    }
335                }
336            }
337            if let Some((_dist, correction)) = best {
338                corrected_terms.push(correction);
339                any_corrected = true;
340            } else {
341                corrected_terms.push(lower);
342            }
343        }
344
345        if any_corrected {
346            Some(corrected_terms.join(" "))
347        } else {
348            None
349        }
350    }
351
352    /// Collect unique lowercase words from all indexed documents.
353    fn vocabulary(&self) -> Vec<String> {
354        let mut stmt = match self.db.prepare("SELECT content FROM docs") {
355            Ok(s) => s,
356            Err(_) => return Vec::new(),
357        };
358        let rows = match stmt.query_map([], |row| row.get::<_, String>(0)) {
359            Ok(r) => r,
360            Err(_) => return Vec::new(),
361        };
362
363        let mut words = std::collections::HashSet::new();
364        for row in rows.flatten() {
365            for word in row.split_whitespace() {
366                let w: String = word
367                    .chars()
368                    .filter(|c| c.is_alphanumeric())
369                    .collect::<String>()
370                    .to_lowercase();
371                if w.len() >= 2 {
372                    words.insert(w);
373                }
374            }
375        }
376        words.into_iter().collect()
377    }
378
379    /// Proximity reranking: boost results where query terms appear close
380    /// together in the document content.
381    fn proximity_rerank(&self, results: &mut Vec<SearchResult>, query_terms: &[&str]) {
382        for r in results.iter_mut() {
383            let content = match self.get_content(&r.id) {
384                Ok(c) => c,
385                Err(_) => continue,
386            };
387            let lower = content.to_lowercase();
388
389            // Find positions of each query term.
390            let mut positions: Vec<usize> = Vec::new();
391            for term in query_terms {
392                if let Some(pos) = lower.find(&term.to_lowercase()) {
393                    positions.push(pos);
394                }
395            }
396
397            if positions.len() >= 2 {
398                positions.sort_unstable();
399                // Compute the span (distance between first and last term).
400                let span = positions.last().unwrap() - positions.first().unwrap();
401                // Boost: closer terms → higher boost.  A span of 0 gives max
402                // boost of 2×; very distant terms give ~1× (no boost).
403                let boost = 1.0 + 1.0 / (1.0 + span as f64 / 50.0);
404                r.score *= boost;
405            }
406        }
407
408        // Re-sort after boosting.
409        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
410    }
411}
412
413// ── Tests ─────────────────────────────────────────────────────────────────────
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty search engine for a single test.
    fn make_search() -> AdvancedSearch {
        AdvancedSearch::new().unwrap()
    }

    #[test]
    fn test_index_and_bm25_search() {
        let s = make_search();
        s.index("d1", "the quick brown fox jumps over the lazy dog").unwrap();
        s.index("d2", "a fast red car drives on the highway").unwrap();

        let results = s.bm25_search("fox");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "d1");
    }

    #[test]
    fn test_trigram_search() {
        let s = make_search();
        s.index("d1", "authentication middleware handles tokens").unwrap();
        s.index("d2", "database migration scripts for postgres").unwrap();

        let results = s.trigram_search("auth");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "d1");
    }

    #[test]
    fn test_rrf_merge_both_lists() {
        let s = make_search();
        s.index("d1", "rust programming language systems").unwrap();
        s.index("d2", "rust prevention coating for metal surfaces").unwrap();
        s.index("d3", "programming in python is fun").unwrap();

        // "rust programming" should match d1 in both porter and trigram,
        // giving it a higher RRF score than d2 or d3.
        let results = s.search("rust programming").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_rrf_docs_in_both_rank_higher() {
        let s = make_search();
        // d1 and d2 contain the word "alpha": found by BM25 and by trigram.
        s.index("d1", "alpha beta gamma delta").unwrap();
        s.index("d2", "alpha only here nothing else relevant").unwrap();
        // d3 only contains "alpha" as a substring of another word, so the
        // trigram index finds it but the word-based BM25 index does not.
        s.index("d3", "alphabet soup").unwrap();

        let bm25 = s.bm25_search("alpha");
        let trigram = s.trigram_search("alpha");

        let in_bm25: std::collections::HashSet<_> = bm25.iter().map(|(_, id)| id.clone()).collect();
        let in_trigram: std::collections::HashSet<_> = trigram.iter().map(|(_, id)| id.clone()).collect();
        let in_both: std::collections::HashSet<_> = in_bm25.intersection(&in_trigram).cloned().collect();

        assert!(in_both.contains("d1") && in_both.contains("d2"));
        assert!(in_trigram.contains("d3"));
        assert!(!in_bm25.contains("d3"));

        let merged = s.reciprocal_rank_fusion(&bm25, &trigram);
        assert_eq!(merged.len(), 3);

        // Docs appearing in both lists must outrank the doc found by only
        // one of them, so the single-list doc ends up last.
        assert_eq!(merged[2].id, "d3");
        for r in &merged[..2] {
            assert!(in_both.contains(&r.id));
            assert!(
                r.score > merged[2].score,
                "{} (in both lists, score {}) must outscore d3 (score {})",
                r.id,
                r.score,
                merged[2].score
            );
        }
    }

    #[test]
    fn test_fuzzy_correction() {
        let s = make_search();
        s.index("d1", "authentication middleware").unwrap();
        s.index("d2", "database migration").unwrap();

        // "authentcation" is a typo (missing 'i'), Levenshtein distance 1.
        let corrected = s.fuzzy_correct("authentcation");
        assert!(corrected.is_some());
        let c = corrected.unwrap();
        assert!(c.contains("authentication"), "corrected to: {}", c);
    }

    #[test]
    fn test_fuzzy_search_end_to_end() {
        let s = make_search();
        s.index("d1", "authentication middleware handles tokens").unwrap();

        // Typo query — should still find d1 via fuzzy correction.
        let results = s.search("authentcation").unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_proximity_reranking() {
        let s = make_search();
        // d1: terms "error" and "handler" are close together.
        s.index("d1", "the error handler catches all exceptions").unwrap();
        // d2: terms "error" and "handler" are far apart.
        s.index(
            "d2",
            "an error occurred in the system and after many lines of unrelated text the handler was invoked",
        ).unwrap();

        let results = s.search("error handler").unwrap();
        assert!(results.len() >= 2);
        // d1 should rank higher due to proximity boost.
        assert_eq!(results[0].id, "d1");
    }

    #[test]
    fn test_smart_snippet_extraction() {
        let content = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. \
                        The authentication module verifies JWT tokens. \
                        Sed do eiusmod tempor incididunt ut labore.";
        let snippet = extract_snippet(content, &["authentication"], 40);
        assert!(snippet.contains("authentication"));
        // Should have ellipsis since it's in the middle.
        assert!(snippet.contains("…"));
    }

    #[test]
    fn test_empty_query_returns_empty() {
        let s = make_search();
        s.index("d1", "some content").unwrap();
        let results = s.search("").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_no_results_returns_empty() {
        let s = make_search();
        s.index("d1", "hello world").unwrap();
        let results = s.search("zzzznonexistent").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_index_upsert() {
        let s = make_search();
        s.index("d1", "original content about cats").unwrap();
        s.index("d1", "updated content about dogs").unwrap();

        let results = s.search("dogs").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "d1");

        let results = s.search("cats").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
        assert_eq!(levenshtein("hello", "hello"), 0);
        assert_eq!(levenshtein("hello", "helo"), 1);
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", ""), 3);
    }

    #[test]
    fn test_snippet_at_start() {
        let content = "authentication is important for security";
        let snippet = extract_snippet(content, &["authentication"], 80);
        assert!(snippet.contains("authentication"));
        // Should not have leading ellipsis since match is at start.
        assert!(!snippet.starts_with('…'));
    }

    #[test]
    fn test_multiple_documents_search() {
        let s = make_search();
        for i in 0..10 {
            s.index(&format!("d{}", i), &format!("document number {} about testing", i))
                .unwrap();
        }
        let results = s.search("testing").unwrap();
        assert_eq!(results.len(), 10);
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;
        use std::collections::{HashMap, HashSet};

        // **Validates: Requirements 41.1, 41.2**
        //
        // Property 41: Advanced search RRF merges correctly
        //
        // For any two ranked lists, the Reciprocal Rank Fusion merge SHALL:
        // 1. Produce a result set containing all unique documents from both inputs.
        // 2. Assign a higher RRF score to documents appearing in both lists than
        //    to documents appearing in only one list (at comparable rank positions).
        proptest! {
            #[test]
            fn prop_rrf_merge_contains_all_unique_docs_and_both_rank_higher(
                // 1-3 documents shared by both lists, plus 1-3 documents
                // unique to each list, so overlap is guaranteed.
                shared_count in 1..4usize,
                a_only_count in 1..4usize,
                b_only_count in 1..4usize,
            ) {
                let s = make_search();

                // Build document sets: shared docs appear in both lists,
                // a_only docs appear only in list A, b_only only in list B.
                let mut list_a: Vec<(f64, String)> = Vec::new();
                let mut list_b: Vec<(f64, String)> = Vec::new();

                // Shared documents — present in both lists.
                for i in 0..shared_count {
                    let id = format!("shared_{}", i);
                    // Use descending scores so rank = index.
                    list_a.push((-(i as f64), id.clone()));
                    list_b.push((-(i as f64), id));
                }

                // A-only documents.
                for i in 0..a_only_count {
                    let id = format!("a_only_{}", i);
                    list_a.push((-((shared_count + i) as f64), id));
                }

                // B-only documents.
                for i in 0..b_only_count {
                    let id = format!("b_only_{}", i);
                    list_b.push((-((shared_count + i) as f64), id));
                }

                let merged = s.reciprocal_rank_fusion(&list_a, &list_b);

                // ── Property 1: merged contains all unique docs from both lists ──
                let all_ids: HashSet<String> = list_a.iter().map(|(_, id)| id.clone())
                    .chain(list_b.iter().map(|(_, id)| id.clone()))
                    .collect();
                let merged_ids: HashSet<String> = merged.iter().map(|r| r.id.clone()).collect();
                prop_assert_eq!(
                    merged_ids, all_ids,
                    "Merged result set must contain all unique documents from both input lists"
                );

                // ── Property 2: docs in both lists score higher than docs in only one ──
                let a_ids: HashSet<String> = list_a.iter().map(|(_, id)| id.clone()).collect();
                let b_ids: HashSet<String> = list_b.iter().map(|(_, id)| id.clone()).collect();
                let in_both: HashSet<String> = a_ids.intersection(&b_ids).cloned().collect();
                let in_one_only: HashSet<String> = a_ids.symmetric_difference(&b_ids).cloned().collect();

                if !in_both.is_empty() && !in_one_only.is_empty() {
                    let scores: HashMap<String, f64> = merged.iter()
                        .map(|r| (r.id.clone(), r.score))
                        .collect();

                    let min_both_score = in_both.iter()
                        .filter_map(|id| scores.get(id))
                        .cloned()
                        .fold(f64::INFINITY, f64::min);

                    let max_one_score = in_one_only.iter()
                        .filter_map(|id| scores.get(id))
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);

                    prop_assert!(
                        min_both_score > max_one_score,
                        "Documents in both lists (min score {}) must score higher \
                         than documents in only one list (max score {})",
                        min_both_score, max_one_score
                    );
                }
            }
        }
    }
}