spark_bert/
tf_term_query.rs1use tantivy::postings::{Postings, SegmentPostings};
2use tantivy::query::{EmptyScorer, EnableScoring, Explanation, Query, Scorer, Weight};
3use tantivy::schema::IndexRecordOption;
4use tantivy::{DocId, DocSet, Score, SegmentReader, Term};
5
6struct TfScorer {
8 postings: SegmentPostings,
9}
10
11impl DocSet for TfScorer {
12 fn advance(&mut self) -> DocId {
13 self.postings.advance();
14 self.doc()
15 }
16
17 fn doc(&self) -> DocId {
18 self.postings.doc()
19 }
20
21 fn size_hint(&self) -> u32 {
22 unimplemented!()
23 }
24}
25
26impl Scorer for TfScorer {
27 fn score(&mut self) -> Score {
28 self.postings.term_freq() as Score
29 }
30}
31
32#[derive(Debug, Clone)]
33pub struct TfTermQuery {
34 term: Term,
35}
36
37impl TfTermQuery {
38 pub fn new(term: Term) -> Self {
39 Self { term }
40 }
41}
42
43struct TfWeight {
44 term: Term,
45}
46
47impl Weight for TfWeight {
48 fn scorer(&self, reader: &SegmentReader, boost: Score) -> tantivy::Result<Box<dyn Scorer>> {
49 let inv = reader.inverted_index(self.term.field())?;
50 if let Some(info) = inv.get_term_info(&self.term)? {
51 let postings = inv.read_postings_from_terminfo(&info, IndexRecordOption::WithFreqs)?;
52 Ok(Box::new(TfScorer { postings }))
53 } else {
54 Ok(Box::new(EmptyScorer))
55 }
56 }
57
58 fn explain(&self, reader: &SegmentReader, doc: DocId) -> tantivy::Result<Explanation> {
59 unimplemented!()
60 }
61}
62
63impl Query for TfTermQuery {
64 fn weight(&self, enable_scoring: EnableScoring<'_>) -> tantivy::Result<Box<dyn Weight>> {
65 let weight = TfWeight {
66 term: self.term.clone(),
67 };
68 tantivy::Result::Ok(Box::new(weight))
69 }
70}