synaptic_retrieval/
bm25.rs1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use synaptic_core::SynapticError;
5
6use crate::{tokenize_to_vec, Document, Retriever};
7
8#[derive(Debug, Clone)]
13pub struct BM25Retriever {
14 documents: Vec<Document>,
15 doc_term_freqs: Vec<HashMap<String, usize>>,
17 doc_lengths: Vec<usize>,
19 avg_doc_length: f64,
21 doc_freq: HashMap<String, usize>,
23 k1: f64,
25 b: f64,
27}
28
29impl BM25Retriever {
30 pub fn new(documents: Vec<Document>) -> Self {
32 Self::with_params(documents, 1.5, 0.75)
33 }
34
35 pub fn with_params(documents: Vec<Document>, k1: f64, b: f64) -> Self {
37 let mut doc_term_freqs = Vec::with_capacity(documents.len());
38 let mut doc_lengths = Vec::with_capacity(documents.len());
39 let mut doc_freq: HashMap<String, usize> = HashMap::new();
40
41 for doc in &documents {
42 let tokens = tokenize_to_vec(&doc.content);
43 let mut term_freq: HashMap<String, usize> = HashMap::new();
44
45 for token in &tokens {
46 *term_freq.entry(token.clone()).or_insert(0) += 1;
47 }
48
49 for term in term_freq.keys() {
51 *doc_freq.entry(term.clone()).or_insert(0) += 1;
52 }
53
54 doc_term_freqs.push(term_freq);
55 doc_lengths.push(tokens.len());
56 }
57
58 let avg_doc_length = if documents.is_empty() {
59 0.0
60 } else {
61 doc_lengths.iter().sum::<usize>() as f64 / documents.len() as f64
62 };
63
64 Self {
65 documents,
66 doc_term_freqs,
67 doc_lengths,
68 avg_doc_length,
69 doc_freq,
70 k1,
71 b,
72 }
73 }
74
75 fn score(&self, doc_idx: usize, query_terms: &[String]) -> f64 {
77 let n = self.documents.len() as f64;
78 let doc_len = self.doc_lengths[doc_idx] as f64;
79 let term_freqs = &self.doc_term_freqs[doc_idx];
80
81 let mut score = 0.0;
82
83 for term in query_terms {
84 let tf = *term_freqs.get(term).unwrap_or(&0) as f64;
85 let df = *self.doc_freq.get(term).unwrap_or(&0) as f64;
86
87 if df == 0.0 || tf == 0.0 {
88 continue;
89 }
90
91 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
93
94 let numerator = tf * (self.k1 + 1.0);
96 let denominator =
97 tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length);
98
99 score += idf * numerator / denominator;
100 }
101
102 score
103 }
104}
105
106#[async_trait]
107impl Retriever for BM25Retriever {
108 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError> {
109 let query_terms = tokenize_to_vec(query);
110
111 if query_terms.is_empty() {
112 return Ok(vec![]);
113 }
114
115 let mut scored: Vec<(f64, usize)> = self
116 .documents
117 .iter()
118 .enumerate()
119 .map(|(idx, _)| (self.score(idx, &query_terms), idx))
120 .filter(|(score, _)| *score > 0.0)
121 .collect();
122
123 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
125
126 Ok(scored
127 .into_iter()
128 .take(top_k)
129 .map(|(_, idx)| self.documents[idx].clone())
130 .collect())
131 }
132}