Skip to main content

spark_bert/
tf_term_query.rs

1use tantivy::postings::{Postings, SegmentPostings};
2use tantivy::query::{EmptyScorer, EnableScoring, Explanation, Query, Scorer, Weight};
3use tantivy::schema::IndexRecordOption;
4use tantivy::{DocId, DocSet, Score, SegmentReader, Term};
5
6// TF Scorer
7struct TfScorer {
8    postings: SegmentPostings,
9}
10
11impl DocSet for TfScorer {
12    fn advance(&mut self) -> DocId {
13        self.postings.advance();
14        self.doc()
15    }
16
17    fn doc(&self) -> DocId {
18        self.postings.doc()
19    }
20
21    fn size_hint(&self) -> u32 {
22        unimplemented!()
23    }
24}
25
26impl Scorer for TfScorer {
27    fn score(&mut self) -> Score {
28        self.postings.term_freq() as Score
29    }
30}
31
32#[derive(Debug, Clone)]
33pub struct TfTermQuery {
34    term: Term,
35}
36
37impl TfTermQuery {
38    pub fn new(term: Term) -> Self {
39        Self { term }
40    }
41}
42
43struct TfWeight {
44    term: Term,
45}
46
47impl Weight for TfWeight {
48    fn scorer(&self, reader: &SegmentReader, boost: Score) -> tantivy::Result<Box<dyn Scorer>> {
49        let inv = reader.inverted_index(self.term.field())?;
50        if let Some(info) = inv.get_term_info(&self.term)? {
51            let postings = inv.read_postings_from_terminfo(&info, IndexRecordOption::WithFreqs)?;
52            Ok(Box::new(TfScorer { postings }))
53        } else {
54            Ok(Box::new(EmptyScorer))
55        }
56    }
57
58    fn explain(&self, reader: &SegmentReader, doc: DocId) -> tantivy::Result<Explanation> {
59        unimplemented!()
60    }
61}
62
63impl Query for TfTermQuery {
64    fn weight(&self, enable_scoring: EnableScoring<'_>) -> tantivy::Result<Box<dyn Weight>> {
65        let weight = TfWeight {
66            term: self.term.clone(),
67        };
68        tantivy::Result::Ok(Box::new(weight))
69    }
70}