use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use tantivy::{
    Index, IndexReader, TantivyDocument, Term,
    collector::TopDocs,
    query::{AllQuery, BooleanQuery, ConstScoreQuery, Occur, Query, QueryParser, TermQuery},
    schema::{
        Field, IndexRecordOption, STORED, STRING, Schema, TextFieldIndexing, TextOptions, Value,
    },
    tokenizer::{Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, TextAnalyzer},
};
use tokio::sync::Mutex;
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum StemLanguage {
30 #[default]
31 English,
32 German,
33 French,
34 Spanish,
35 Italian,
36 Portuguese,
37 Russian,
38 None,
40}
41
42impl StemLanguage {
43 fn to_tantivy_language(self) -> Option<Language> {
44 match self {
45 StemLanguage::English => Some(Language::English),
46 StemLanguage::German => Some(Language::German),
47 StemLanguage::French => Some(Language::French),
48 StemLanguage::Spanish => Some(Language::Spanish),
49 StemLanguage::Italian => Some(Language::Italian),
50 StemLanguage::Portuguese => Some(Language::Portuguese),
51 StemLanguage::Russian => Some(Language::Russian),
52 StemLanguage::None => None,
53 }
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct BM25Config {
60 #[serde(default = "default_bm25_path")]
62 pub index_path: String,
63 #[serde(default = "default_heap_size")]
65 pub writer_heap_size: usize,
66 #[serde(default = "default_true")]
68 pub enable_stemming: bool,
69 #[serde(default)]
71 pub language: StemLanguage,
72 #[serde(default)]
75 pub read_only: bool,
76}
77
/// Default index directory used by serde and by `BM25Config::default`.
///
/// The string is returned verbatim; the leading `~` is not expanded here.
fn default_bm25_path() -> String {
    String::from("~/.rmcp-servers/rust-memex/bm25")
}
81
/// Default memory budget handed to tantivy's index writer: fifty million (50 × 10⁶).
fn default_heap_size() -> usize {
    50 * 1_000_000
}
85
/// Serde default helper for boolean fields that should default to enabled
/// (serde's `default = "…"` attribute needs a function path, not a literal).
fn default_true() -> bool {
    true
}
89
90impl Default for BM25Config {
91 fn default() -> Self {
92 Self {
93 index_path: default_bm25_path(),
94 writer_heap_size: default_heap_size(),
95 enable_stemming: true,
96 language: StemLanguage::English,
97 read_only: false,
98 }
99 }
100}
101
102impl BM25Config {
103 pub fn multilingual() -> Self {
105 Self {
106 language: StemLanguage::None,
107 enable_stemming: false,
108 ..Self::default()
109 }
110 }
111
112 pub fn read_only() -> Self {
114 Self {
115 read_only: true,
116 ..Self::default()
117 }
118 }
119
120 pub fn with_path(mut self, path: impl Into<String>) -> Self {
121 self.index_path = path.into();
122 self
123 }
124
125 pub fn with_read_only(mut self, read_only: bool) -> Self {
126 self.read_only = read_only;
127 self
128 }
129}
130
131pub struct BM25Index {
137 index: Index,
138 reader: IndexReader,
139 content_field: Field,
140 id_field: Field,
141 namespace_field: Field,
142 writer_heap_size: usize,
144 read_only: bool,
146 write_lock: Arc<Mutex<()>>,
148 index_path: PathBuf,
150}
151
152impl BM25Index {
153 pub fn new(config: &BM25Config) -> Result<Self> {
155 let path = crate::path_utils::sanitize_new_path(&config.index_path)?;
156
157 if !path.exists() {
159 std::fs::create_dir_all(&path)?;
160 }
161
162 let mut schema_builder = Schema::builder();
164
165 let text_options = TextOptions::default()
167 .set_indexing_options(
168 TextFieldIndexing::default()
169 .set_tokenizer("custom_tokenizer")
170 .set_index_option(IndexRecordOption::WithFreqsAndPositions),
171 )
172 .set_stored();
173
174 let content_field = schema_builder.add_text_field("content", text_options);
175 let id_field = schema_builder.add_text_field("id", STRING | STORED);
176 let namespace_field = schema_builder.add_text_field("namespace", STRING | STORED);
177
178 let schema = schema_builder.build();
179
180 let index = if path.join("meta.json").exists() {
182 Index::open_in_dir(&path)?
183 } else {
184 Index::create_in_dir(&path, schema.clone())?
185 };
186
187 let tokenizer = if config.enable_stemming {
189 if let Some(lang) = config.language.to_tantivy_language() {
190 TextAnalyzer::builder(SimpleTokenizer::default())
191 .filter(RemoveLongFilter::limit(40))
192 .filter(LowerCaser)
193 .filter(Stemmer::new(lang))
194 .build()
195 } else {
196 TextAnalyzer::builder(SimpleTokenizer::default())
198 .filter(RemoveLongFilter::limit(40))
199 .filter(LowerCaser)
200 .build()
201 }
202 } else {
203 TextAnalyzer::builder(SimpleTokenizer::default())
204 .filter(RemoveLongFilter::limit(40))
205 .filter(LowerCaser)
206 .build()
207 };
208
209 index.tokenizers().register("custom_tokenizer", tokenizer);
210
211 let reader = index.reader()?;
212
213 if config.read_only {
214 tracing::info!("BM25 index opened in READ-ONLY mode");
215 } else {
216 tracing::debug!("BM25 index opened (on-demand lock acquisition for writes)");
217 }
218
219 Ok(Self {
220 index,
221 reader,
222 content_field,
223 id_field,
224 namespace_field,
225 writer_heap_size: config.writer_heap_size,
226 read_only: config.read_only,
227 write_lock: Arc::new(Mutex::new(())),
228 index_path: path,
229 })
230 }
231
232 pub fn is_read_only(&self) -> bool {
234 self.read_only
235 }
236
237 async fn with_writer<F, T>(&self, operation: F) -> Result<T>
242 where
243 F: FnOnce(&mut tantivy::IndexWriter) -> Result<T>,
244 {
245 if self.read_only {
246 return Err(anyhow!("Cannot write: BM25 index is in read-only mode"));
247 }
248
249 let _guard = self.write_lock.lock().await;
251
252 const MAX_RETRIES: u32 = 5;
254 const INITIAL_BACKOFF_MS: u64 = 50;
255 const MAX_BACKOFF_MS: u64 = 2000;
256
257 let mut attempt = 0;
258 let mut backoff_ms = INITIAL_BACKOFF_MS;
259
260 let mut writer = loop {
261 match self.index.writer(self.writer_heap_size) {
262 Ok(w) => break w,
263 Err(e) => {
264 let is_lock_busy = e.to_string().contains("LockBusy");
265
266 if is_lock_busy && attempt < MAX_RETRIES {
267 attempt += 1;
268 tracing::debug!(
269 "BM25 lock busy, retry {}/{} in {}ms. Path: {:?}",
270 attempt,
271 MAX_RETRIES,
272 backoff_ms,
273 self.index_path
274 );
275 tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
276 backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
277 } else if is_lock_busy {
278 return Err(anyhow!(
279 "BM25 index locked after {} retries. Path: {:?}. \
280 Multiple processes writing simultaneously - try again.",
281 MAX_RETRIES,
282 self.index_path
283 ));
284 } else {
285 return Err(anyhow!("Failed to acquire BM25 writer: {}", e));
286 }
287 }
288 }
289 };
290
291 let result = operation(&mut writer)?;
293
294 writer.commit()?;
296
297 drop(writer);
299
300 self.reader.reload()?;
302
303 Ok(result)
304 }
305
306 pub async fn add_documents(&self, docs: &[(String, String, String)]) -> Result<()> {
316 let content_field = self.content_field;
317 let id_field = self.id_field;
318 let namespace_field = self.namespace_field;
319 let doc_count = docs.len();
320
321 let docs = docs.to_vec();
323
324 self.with_writer(move |writer| {
325 for (id, namespace, content) in &docs {
326 let mut doc = TantivyDocument::new();
327 doc.add_text(content_field, content);
328 doc.add_text(id_field, id);
329 doc.add_text(namespace_field, namespace);
330 writer.add_document(doc)?;
331 }
332 Ok(())
333 })
334 .await?;
335
336 tracing::debug!("Added {} documents to BM25 index", doc_count);
337 Ok(())
338 }
339
340 pub fn search(
350 &self,
351 query: &str,
352 namespace: Option<&str>,
353 limit: usize,
354 ) -> Result<Vec<(String, String, f32)>> {
355 let searcher = self.reader.searcher();
356
357 let query_parser = QueryParser::for_index(&self.index, vec![self.content_field]);
359
360 let escaped_query = Self::escape_query(query);
362 let parsed_query = query_parser
363 .parse_query(&escaped_query)
364 .map_err(|e| anyhow!("Query parse error: {}", e))?;
365
366 let top_docs = searcher.search(&parsed_query, &TopDocs::with_limit(limit * 2))?;
368
369 let mut results = Vec::with_capacity(limit);
370
371 for (score, doc_address) in top_docs {
372 let doc: TantivyDocument = searcher.doc(doc_address)?;
373
374 let id = doc
376 .get_first(self.id_field)
377 .and_then(|v| Value::as_str(&v).map(|s| s.to_string()))
378 .ok_or_else(|| anyhow!("Document missing ID field"))?;
379 let doc_namespace = doc
380 .get_first(self.namespace_field)
381 .and_then(|v| Value::as_str(&v).map(|s| s.to_string()))
382 .ok_or_else(|| anyhow!("Document missing namespace field"))?;
383
384 if let Some(ns) = namespace
386 && doc_namespace != ns
387 {
388 continue;
389 }
390
391 results.push((id, doc_namespace, score));
392
393 if results.len() >= limit {
394 break;
395 }
396 }
397
398 tracing::debug!("BM25 search '{}' returned {} results", query, results.len());
399
400 Ok(results)
401 }
402
403 pub async fn delete_documents(&self, ids: &[String]) -> Result<usize> {
410 let id_field = self.id_field;
411 let ids = ids.to_vec();
412 let count = ids.len();
413
414 self.with_writer(move |writer| {
415 for id in &ids {
416 let term = tantivy::Term::from_field_text(id_field, id);
417 writer.delete_term(term);
418 }
419 Ok(count)
420 })
421 .await
422 }
423
424 pub async fn delete_namespace_term(&self, namespace: &str) -> Result<usize> {
431 let namespace_field = self.namespace_field;
432 let namespace_owned = namespace.to_string();
433 let namespace_log = namespace.to_string();
434
435 self.with_writer(move |writer| {
436 let term = tantivy::Term::from_field_text(namespace_field, &namespace_owned);
437 writer.delete_term(term);
438 Ok(1) })
440 .await?;
441
442 tracing::info!("Purged namespace '{}' from BM25 index", namespace_log);
443 Ok(1)
444 }
445
446 fn escape_query(query: &str) -> String {
448 let special_chars = [
450 '+', '-', '&', '|', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':',
451 '\\', '/',
452 ];
453
454 let mut escaped = String::with_capacity(query.len() * 2);
455 for c in query.chars() {
456 if special_chars.contains(&c) {
457 escaped.push('\\');
458 }
459 escaped.push(c);
460 }
461 escaped
462 }
463
464 pub fn doc_count(&self) -> u64 {
466 let searcher = self.reader.searcher();
467 searcher.num_docs()
468 }
469
470 pub fn document_keys(&self, namespace: Option<&str>) -> Result<HashSet<(String, String)>> {
475 let searcher = self.reader.searcher();
476 let total = usize::try_from(searcher.num_docs()).unwrap_or(usize::MAX);
477 if total == 0 {
478 return Ok(HashSet::new());
479 }
480
481 let all_query = AllQuery;
482 let top_docs = searcher.search(&all_query, &TopDocs::with_limit(total))?;
483 let mut keys = HashSet::with_capacity(total);
484
485 for (_score, doc_address) in top_docs {
486 let doc: TantivyDocument = searcher.doc(doc_address)?;
487 let id = doc
488 .get_first(self.id_field)
489 .and_then(|value| Value::as_str(&value).map(|value| value.to_string()))
490 .ok_or_else(|| anyhow!("Document missing ID field"))?;
491 let doc_namespace = doc
492 .get_first(self.namespace_field)
493 .and_then(|value| Value::as_str(&value).map(|value| value.to_string()))
494 .ok_or_else(|| anyhow!("Document missing namespace field"))?;
495
496 if namespace.is_none_or(|expected| expected == doc_namespace) {
497 keys.insert((doc_namespace, id));
498 }
499 }
500
501 Ok(keys)
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds an owned `(id, namespace, content)` triple in the shape `add_documents` takes.
    fn entry(id: &str, namespace: &str, content: &str) -> (String, String, String) {
        (id.to_string(), namespace.to_string(), content.to_string())
    }

    /// Opens an index with the default configuration rooted at `dir`.
    fn open_at(dir: &str) -> BM25Index {
        BM25Index::new(&BM25Config::default().with_path(dir)).unwrap()
    }

    #[tokio::test]
    async fn test_bm25_basic() {
        let dir = TempDir::new().unwrap();
        let index = open_at(dir.path().to_str().unwrap());

        let corpus = [
            entry("doc1", "test", "The quick brown fox jumps over the lazy dog"),
            entry("doc2", "test", "A quick brown dog runs in the park"),
            entry("doc3", "test", "The lazy cat sleeps all day"),
        ];
        index.add_documents(&corpus).await.unwrap();

        let hits = index.search("quick brown", None, 10).unwrap();
        assert_eq!(hits.len(), 2);

        let found: Vec<&str> = hits.iter().map(|(id, _, _)| id.as_str()).collect();
        assert!(found.contains(&"doc1"));
        assert!(found.contains(&"doc2"));
    }

    #[tokio::test]
    async fn test_bm25_namespace_filter() {
        let dir = TempDir::new().unwrap();
        let index = open_at(dir.path().to_str().unwrap());

        let corpus = [
            entry("doc1", "ns1", "hello world"),
            entry("doc2", "ns2", "hello universe"),
        ];
        index.add_documents(&corpus).await.unwrap();

        let hits = index.search("hello", Some("ns1"), 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, "doc1");
        assert_eq!(hits[0].1, "ns1");
    }

    #[tokio::test]
    async fn test_bm25_delete_documents_removes_exact_id_matches() {
        let dir = TempDir::new().unwrap();
        let index = open_at(dir.path().to_str().unwrap());

        let corpus = [
            entry("doc1", "team:alpha", "shared search term"),
            entry("doc2", "team:alpha", "shared search term"),
        ];
        index.add_documents(&corpus).await.unwrap();
        assert_eq!(index.search("shared", None, 10).unwrap().len(), 2);

        let removed = index.delete_documents(&["doc1".to_string()]).await.unwrap();
        assert_eq!(removed, 1);

        let survivors = index.search("shared", None, 10).unwrap();
        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0].0, "doc2");
    }

    #[tokio::test]
    async fn test_bm25_purge_namespace_matches_exact_string() {
        let dir = TempDir::new().unwrap();
        let index = open_at(dir.path().to_str().unwrap());

        let corpus = [
            entry("doc1", "team:alpha", "shared search term"),
            entry("doc2", "team:beta", "shared search term"),
        ];
        index.add_documents(&corpus).await.unwrap();
        assert_eq!(index.search("shared", None, 10).unwrap().len(), 2);

        let removed = index.delete_namespace_term("team:alpha").await.unwrap();
        assert_eq!(removed, 1);

        let alpha_hits = index.search("shared", Some("team:alpha"), 10).unwrap();
        assert!(alpha_hits.is_empty());

        let survivors = index.search("shared", None, 10).unwrap();
        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0].0, "doc2");
        assert_eq!(survivors[0].1, "team:beta");
    }

    #[tokio::test]
    async fn test_bm25_lock_release() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().to_str().unwrap();

        // The first instance writes and is then dropped; no writer lock may outlive it.
        {
            let first = open_at(path);
            first
                .add_documents(&[entry("doc1", "ns", "hello world")])
                .await
                .unwrap();
        }

        // A second instance on the same directory must be able to write again.
        let second = open_at(path);
        second
            .add_documents(&[entry("doc2", "ns", "hello there")])
            .await
            .unwrap();

        assert_eq!(second.search("hello", None, 10).unwrap().len(), 2);
    }

    #[test]
    fn test_escape_query() {
        let cases = [
            ("hello world", "hello world"),
            ("hello+world", "hello\\+world"),
            ("test:query", "test\\:query"),
        ];
        for (input, expected) in cases {
            assert_eq!(BM25Index::escape_query(input), expected);
        }
    }
}