1use std::cmp::{Ordering, Reverse};
41use std::collections::{BinaryHeap, HashMap, HashSet};
42
43use parking_lot::RwLock;
44
45use crate::bm25::{BM25Config, BM25Scorer, tokenize_minimal};
46
/// Identifier of an indexed document.
pub type DocId = u64;

/// Zero-based index of a token within its document's token sequence.
pub type Position = u32;

/// Number of occurrences of one term inside one document.
pub type TermFreq = u32;
59
60struct ScoredDoc {
66 score: f32,
67 doc_id: DocId,
68}
69
70impl PartialEq for ScoredDoc {
71 fn eq(&self, other: &Self) -> bool {
72 self.cmp(other) == Ordering::Equal
73 }
74}
75
76impl Eq for ScoredDoc {}
77
78impl Ord for ScoredDoc {
79 fn cmp(&self, other: &Self) -> Ordering {
80 self.score
81 .total_cmp(&other.score)
82 .then_with(|| self.doc_id.cmp(&other.doc_id))
83 }
84}
85
86impl PartialOrd for ScoredDoc {
87 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
92#[derive(Debug, Clone)]
98pub struct Posting {
99 pub doc_id: DocId,
101
102 pub term_freq: TermFreq,
104
105 pub positions: Option<Vec<Position>>,
107}
108
109impl Posting {
110 pub fn new(doc_id: DocId, term_freq: TermFreq) -> Self {
112 Self {
113 doc_id,
114 term_freq,
115 positions: None,
116 }
117 }
118
119 pub fn with_positions(doc_id: DocId, positions: Vec<Position>) -> Self {
121 Self {
122 doc_id,
123 term_freq: positions.len() as TermFreq,
124 positions: Some(positions),
125 }
126 }
127}
128
129#[derive(Debug, Clone, Default)]
131pub struct PostingList {
132 postings: Vec<Posting>,
134}
135
136impl PostingList {
137 pub fn new() -> Self {
139 Self {
140 postings: Vec::new(),
141 }
142 }
143
144 pub fn add(&mut self, posting: Posting) {
146 match self
147 .postings
148 .binary_search_by_key(&posting.doc_id, |p| p.doc_id)
149 {
150 Ok(idx) => {
151 self.postings[idx] = posting;
153 }
154 Err(idx) => {
155 self.postings.insert(idx, posting);
157 }
158 }
159 }
160
161 pub fn get(&self, doc_id: DocId) -> Option<&Posting> {
163 self.postings
164 .binary_search_by_key(&doc_id, |p| p.doc_id)
165 .ok()
166 .map(|idx| &self.postings[idx])
167 }
168
169 pub fn doc_freq(&self) -> usize {
171 self.postings.len()
172 }
173
174 pub fn iter(&self) -> impl Iterator<Item = &Posting> {
176 self.postings.iter()
177 }
178
179 pub fn doc_ids(&self) -> Vec<DocId> {
181 self.postings.iter().map(|p| p.doc_id).collect()
182 }
183}
184
185#[derive(Debug, Clone)]
191pub struct DocumentInfo {
192 pub length: u32,
194
195 pub term_freqs: HashMap<String, TermFreq>,
197}
198
199pub struct InvertedIndex {
205 index: RwLock<HashMap<String, PostingList>>,
207
208 docs: RwLock<HashMap<DocId, DocumentInfo>>,
210
211 scorer: RwLock<BM25Scorer>,
213
214 next_doc_id: RwLock<DocId>,
216
217 store_positions: bool,
219}
220
221impl InvertedIndex {
222 pub fn new(config: BM25Config) -> Self {
224 Self {
225 index: RwLock::new(HashMap::new()),
226 docs: RwLock::new(HashMap::new()),
227 scorer: RwLock::new(BM25Scorer::new(config)),
228 next_doc_id: RwLock::new(0),
229 store_positions: false,
230 }
231 }
232
233 pub fn with_positions(mut self) -> Self {
235 self.store_positions = true;
236 self
237 }
238
239 pub fn add_document(&self, text: &str) -> DocId {
243 let tokens = tokenize_minimal(text);
244 self.add_document_tokens(&tokens)
245 }
246
247 pub fn add_document_with_id(&self, doc_id: DocId, text: &str) {
249 let tokens = tokenize_minimal(text);
250 self.add_document_tokens_with_id(doc_id, &tokens);
251 }
252
253 pub fn clear(&self) {
256 let config = self.scorer.read().config();
257 self.index.write().clear();
258 self.docs.write().clear();
259 *self.scorer.write() = BM25Scorer::new(config);
260 *self.next_doc_id.write() = 0;
261 }
262
263 pub fn rebuild_from_documents<'a, I>(&self, documents: I)
285 where
286 I: IntoIterator<Item = (DocId, &'a str)>,
287 {
288 self.clear();
289 let mut max_id: Option<DocId> = None;
290 for (doc_id, text) in documents {
291 self.add_document_with_id(doc_id, text);
292 max_id = Some(max_id.map_or(doc_id, |m| m.max(doc_id)));
293 }
294 if let Some(m) = max_id {
297 *self.next_doc_id.write() = m + 1;
298 }
299 }
300
301 pub fn add_document_tokens(&self, tokens: &[String]) -> DocId {
303 let doc_id = {
304 let mut next = self.next_doc_id.write();
305 let id = *next;
306 *next += 1;
307 id
308 };
309
310 self.add_document_tokens_with_id(doc_id, tokens);
311 doc_id
312 }
313
314 pub fn add_document_tokens_with_id(&self, doc_id: DocId, tokens: &[String]) {
316 let mut term_freqs: HashMap<String, TermFreq> = HashMap::new();
318 let mut term_positions: HashMap<String, Vec<Position>> = HashMap::new();
319
320 for (pos, token) in tokens.iter().enumerate() {
321 *term_freqs.entry(token.clone()).or_insert(0) += 1;
322 if self.store_positions {
323 term_positions
324 .entry(token.clone())
325 .or_default()
326 .push(pos as Position);
327 }
328 }
329
330 {
332 let mut index = self.index.write();
333 for (term, tf) in &term_freqs {
334 let posting = if self.store_positions {
335 Posting::with_positions(
336 doc_id,
337 term_positions.get(term).cloned().unwrap_or_default(),
338 )
339 } else {
340 Posting::new(doc_id, *tf)
341 };
342
343 index.entry(term.clone()).or_default().add(posting);
344 }
345 }
346
347 {
349 let mut docs = self.docs.write();
350 docs.insert(
351 doc_id,
352 DocumentInfo {
353 length: tokens.len() as u32,
354 term_freqs,
355 },
356 );
357 }
358
359 {
361 let mut scorer = self.scorer.write();
362 scorer.add_document(tokens.iter().map(|s| s.as_str()));
363 }
364 }
365
366 pub fn remove_document(&self, doc_id: DocId) -> bool {
368 let doc_info = {
369 let mut docs = self.docs.write();
370 docs.remove(&doc_id)
371 };
372
373 if let Some(info) = doc_info {
374 {
377 let mut index = self.index.write();
378 for term in info.term_freqs.keys() {
379 let now_empty = if let Some(posting_list) = index.get_mut(term) {
380 posting_list.postings.retain(|p| p.doc_id != doc_id);
381 posting_list.postings.is_empty()
382 } else {
383 false
384 };
385 if now_empty {
386 index.remove(term);
387 }
388 }
389 }
390
391 {
394 let mut scorer = self.scorer.write();
395 scorer.remove_document(
396 info.term_freqs.keys().map(|s| s.as_str()),
397 info.length as usize,
398 );
399 }
400 true
401 } else {
402 false
403 }
404 }
405
406 pub fn search(&self, query: &str, limit: usize) -> Vec<(DocId, f32)> {
410 let query_tokens = tokenize_minimal(query);
411 if query_tokens.is_empty() {
412 return Vec::new();
413 }
414
415 self.search_tokens(&query_tokens, limit)
416 }
417
418 pub fn search_tokens(&self, query_tokens: &[String], limit: usize) -> Vec<(DocId, f32)> {
420 if query_tokens.is_empty() {
421 return Vec::new();
422 }
423
424 let index = self.index.read();
425 let docs = self.docs.read();
426 let scorer = self.scorer.read();
427
428 let mut candidates: HashSet<DocId> = HashSet::new();
430 for token in query_tokens {
431 if let Some(posting_list) = index.get(token) {
432 for posting in posting_list.iter() {
433 candidates.insert(posting.doc_id);
434 }
435 }
436 }
437
438 let mut heap: BinaryHeap<Reverse<ScoredDoc>> = BinaryHeap::with_capacity(limit + 1);
443 for doc_id in candidates {
444 let Some(doc_info) = docs.get(&doc_id) else {
445 continue;
446 };
447 let score = scorer.score_with_tf_u32(
448 query_tokens,
449 &doc_info.term_freqs,
450 doc_info.length as usize,
451 );
452 if score <= 0.0 {
453 continue;
454 }
455 if limit == 0 {
456 continue;
457 }
458 heap.push(Reverse(ScoredDoc { score, doc_id }));
459 if heap.len() > limit {
460 heap.pop();
461 }
462 }
463
464 let mut results: Vec<(DocId, f32)> = heap
469 .into_iter()
470 .map(|Reverse(sd)| (sd.doc_id, sd.score))
471 .collect();
472 results.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
473
474 results
475 }
476
477 pub fn get_posting_list(&self, term: &str) -> Option<PostingList> {
479 self.index.read().get(&term.to_lowercase()).cloned()
480 }
481
482 pub fn num_documents(&self) -> usize {
484 self.docs.read().len()
485 }
486
487 pub fn vocab_size(&self) -> usize {
489 self.index.read().len()
490 }
491
492 pub fn get_document_info(&self, doc_id: DocId) -> Option<DocumentInfo> {
494 self.docs.read().get(&doc_id).cloned()
495 }
496
497 pub fn has_document(&self, doc_id: DocId) -> bool {
499 self.docs.read().contains_key(&doc_id)
500 }
501}
502
503pub struct InvertedIndexBuilder {
509 config: BM25Config,
510 store_positions: bool,
511}
512
513impl InvertedIndexBuilder {
514 pub fn new() -> Self {
516 Self {
517 config: BM25Config::default(),
518 store_positions: false,
519 }
520 }
521
522 pub fn with_config(mut self, config: BM25Config) -> Self {
524 self.config = config;
525 self
526 }
527
528 pub fn with_positions(mut self) -> Self {
530 self.store_positions = true;
531 self
532 }
533
534 pub fn build<I>(self, documents: I) -> InvertedIndex
536 where
537 I: IntoIterator<Item = (DocId, String)>,
538 {
539 let index = if self.store_positions {
540 InvertedIndex::new(self.config).with_positions()
541 } else {
542 InvertedIndex::new(self.config)
543 };
544
545 for (doc_id, text) in documents {
546 index.add_document_with_id(doc_id, &text);
547 }
548
549 index
550 }
551}
552
553impl Default for InvertedIndexBuilder {
554 fn default() -> Self {
555 Self::new()
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty index with default BM25 parameters.
    fn fresh() -> InvertedIndex {
        InvertedIndex::new(BM25Config::default())
    }

    #[test]
    fn test_posting_list() {
        let mut postings = PostingList::new();
        for (doc, tf) in [(1, 2), (3, 1), (2, 3)] {
            postings.add(Posting::new(doc, tf));
        }

        assert_eq!(postings.doc_freq(), 3);

        // Inserted as 1, 3, 2 — must come back sorted by doc id.
        assert_eq!(postings.doc_ids(), vec![1, 2, 3]);

        let middle = postings.get(2).unwrap();
        assert_eq!(middle.term_freq, 3);
    }

    #[test]
    fn test_add_document() {
        let idx = fresh();

        let first = idx.add_document("hello world");
        let second = idx.add_document("hello there");

        assert_eq!(first, 0);
        assert_eq!(second, 1);
        assert_eq!(idx.num_documents(), 2);

        let hello = idx.get_posting_list("hello").unwrap();
        assert_eq!(hello.doc_freq(), 2);
    }

    #[test]
    fn test_search() {
        let idx = fresh();
        for text in [
            "the quick brown fox jumps over the lazy dog",
            "quick quick quick fox",
            "lazy lazy lazy dog",
        ] {
            idx.add_document(text);
        }

        let hits = idx.search("quick", 10);
        assert!(!hits.is_empty());
        // Doc 1 repeats "quick" in a short document and is expected on top.
        assert_eq!(hits[0].0, 1);
    }

    #[test]
    fn test_search_multi_term() {
        let idx = fresh();
        for text in ["apple banana cherry", "apple banana", "apple"] {
            idx.add_document(text);
        }

        let hits = idx.search("apple banana cherry", 10);
        // Doc 0 is the only one matching all three query terms.
        assert_eq!(hits[0].0, 0);
    }

    #[test]
    fn test_search_topk_bound_matches_full_sort() {
        let idx = fresh();
        // Doc i contains "alpha" (i + 1) times plus one unique token.
        for i in 0..20 {
            let body = vec!["alpha"; i + 1].join(" ");
            idx.add_document(&format!("{body} doc{i}"));
        }

        let limit = 5;
        let topk = idx.search("alpha", limit);
        assert_eq!(topk.len(), limit, "must return exactly `limit` results");

        for pair in topk.windows(2) {
            assert!(
                pair[0].1 >= pair[1].1,
                "results must be sorted by score descending"
            );
        }

        // A limit larger than the corpus behaves like an unbounded full sort.
        let full = idx.search("alpha", 1000);
        let expected: Vec<u64> = full.iter().take(limit).map(|&(id, _)| id).collect();
        let got: Vec<u64> = topk.iter().map(|&(id, _)| id).collect();
        assert_eq!(
            got, expected,
            "bounded top-k must equal full-sort prefix"
        );
    }

    #[test]
    fn test_rebuild_reproduces_index() {
        let corpus: Vec<(u64, &str)> = vec![
            (10, "the quick brown fox"),
            (11, "the lazy dog sleeps"),
            (12, "quick foxes jump high"),
            (13, "lazy dogs and quick cats"),
        ];

        let reference = fresh();
        for &(id, text) in &corpus {
            reference.add_document_with_id(id, text);
        }

        // Seed the rebuilt index with state the rebuild must wipe out.
        let rebuilt = fresh();
        rebuilt.add_document_with_id(99, "noise document that should vanish");
        rebuilt.add_document_with_id(98, "more transient noise quick fox");
        rebuilt.remove_document(99);
        rebuilt.rebuild_from_documents(corpus.iter().copied());

        assert!(!rebuilt.has_document(99));
        assert!(!rebuilt.has_document(98));
        for &(id, _) in &corpus {
            assert!(rebuilt.has_document(id));
        }

        for q in ["quick", "lazy dog", "fox", "the quick brown"] {
            assert_eq!(
                rebuilt.search(q, 10),
                reference.search(q, 10),
                "rebuilt ranking diverges for query {q:?}",
            );
        }

        let next = rebuilt.add_document("brand new quick doc");
        assert_eq!(next, 14, "auto-id must resume one past max restored id");
    }

    #[test]
    fn test_remove_document() {
        let idx = fresh();

        let gone = idx.add_document("hello world");
        let _kept = idx.add_document("hello there");

        assert!(idx.has_document(gone));
        assert!(idx.remove_document(gone));
        assert!(!idx.has_document(gone));

        let hello = idx.get_posting_list("hello").unwrap();
        assert_eq!(hello.doc_freq(), 1);

        // "world" only appeared in the removed document.
        assert!(idx.get_posting_list("world").is_none());
    }

    #[test]
    fn test_add_remove_equals_never_added() {
        let base = [(1, "the quick brown fox"), (2, "lazy dog sleeps all day")];

        let with_removed = fresh();
        for &(id, text) in &base {
            with_removed.add_document_with_id(id, text);
        }
        let transient = with_removed.add_document("ephemeral zebra quagga");
        assert!(with_removed.remove_document(transient));

        let never_added = fresh();
        for &(id, text) in &base {
            never_added.add_document_with_id(id, text);
        }

        assert_eq!(with_removed.num_documents(), never_added.num_documents());
        assert_eq!(with_removed.vocab_size(), never_added.vocab_size());

        assert!(with_removed.get_posting_list("zebra").is_none());
        assert!(with_removed.get_posting_list("quagga").is_none());

        for q in ["quick", "dog", "fox sleeps"] {
            let lhs = with_removed.search(q, 10);
            let rhs = never_added.search(q, 10);
            assert_eq!(lhs.len(), rhs.len(), "result-count mismatch for {q:?}");
            for (l, r) in lhs.iter().zip(rhs.iter()) {
                assert_eq!(l.0, r.0, "doc_id mismatch for {q:?}");
                // Bit-level comparison: removal must restore statistics exactly.
                assert_eq!(l.1.to_bits(), r.1.to_bits(), "score mismatch for {q:?}");
            }
        }
    }

    #[test]
    fn test_builder() {
        let documents = vec![
            (0, "hello world".to_string()),
            (1, "hello there".to_string()),
            (2, "goodbye world".to_string()),
        ];

        let idx = InvertedIndexBuilder::new()
            .with_config(BM25Config::lucene())
            .build(documents);

        assert_eq!(idx.num_documents(), 3);
        assert!(idx.vocab_size() > 0);
    }

    #[test]
    fn test_positions() {
        let idx = fresh().with_positions();

        let doc_id = idx.add_document("hello world hello");

        let hello = idx.get_posting_list("hello").unwrap();
        let posting = hello.get(doc_id).unwrap();

        assert_eq!(posting.positions, Some(vec![0, 2]));
    }
}