summa_core/components/queries/
exists_query.rs

1use std::ops::RangeInclusive;
2
3use tantivy::query::{BitSetDocSet, ConstScorer, EnableScoring, Explanation, Query, Scorer, Weight};
4use tantivy::schema::{Field, IndexRecordOption};
5use tantivy::{DocId, Result, Score, SegmentReader, TantivyError, Term};
6use tantivy_common::BitSet;
7
8/// An Exists Query matches all of the documents
9/// containing a specific indexed field.
10///
11/// ```rust
12/// use tantivy::collector::Count;
13/// use summa_core::components::queries::ExistsQuery;
14/// use tantivy::schema::{Schema, TEXT};
15/// use tantivy::{doc, Index};
16///
17/// # fn test() -> tantivy::Result<()> {
18/// let mut schema_builder = Schema::builder();
19/// let title = schema_builder.add_text_field("title", TEXT);
20/// let author = schema_builder.add_text_field("author", TEXT);
21/// let schema = schema_builder.build();
22/// let index = Index::create_in_ram(schema);
23/// {
24///     let mut index_writer = index.writer(15_000_000)?;
25///     index_writer.add_document(doc!(
26///         title => "The Name of the Wind",
27///         author => "Patrick Rothfuss"
28///     ))?;
29///     index_writer.add_document(doc!(
30///         title => "The Diary of Muadib",
31///     ))?;
32///     index_writer.add_document(doc!(
33///         title => "A Dairy Cow",
34///         author => "John Webster"
35///     ))?;
36///     index_writer.commit()?;
37/// }
38///
39/// let reader = index.reader()?;
40/// let searcher = reader.searcher();
41///
42/// let query = ExistsQuery::new(author, "");
43/// let count = searcher.search(&query, &Count)?;
44/// assert_eq!(count, 2);
45/// Ok(())
46/// # }
47/// # assert!(test().is_ok());
48/// ```
49
50pub const JSON_SEGMENT_UPPER_TERMINATOR: u8 = 2u8;
51pub const JSON_SEGMENT_UPPER_TERMINATOR_STR: &str = unsafe { std::str::from_utf8_unchecked(&[JSON_SEGMENT_UPPER_TERMINATOR]) };
52
53#[derive(Clone, Debug)]
54pub struct ExistsQuery {
55    field: Field,
56    full_path: String,
57}
58
59impl ExistsQuery {
60    /// Creates a new ExistsQuery with a given field
61    pub fn new(field: Field, full_path: &str) -> Self {
62        ExistsQuery {
63            field,
64            full_path: full_path.to_string(),
65        }
66    }
67}
68
69#[async_trait]
70impl Query for ExistsQuery {
71    fn weight(&self, _: EnableScoring<'_>) -> Result<Box<dyn Weight>> {
72        Ok(Box::new(ExistsWeight {
73            field: self.field,
74            full_path: self.full_path.clone(),
75        }))
76    }
77
78    async fn weight_async(&self, enable_scoring: EnableScoring<'_>) -> Result<Box<dyn Weight>> {
79        self.weight(enable_scoring)
80    }
81}
82
83/// Weight associated with the `ExistsQuery` query.
84pub struct ExistsWeight {
85    field: Field,
86    full_path: String,
87}
88
89impl ExistsWeight {
90    fn get_json_term(&self, json_path: &str) -> Term {
91        Term::from_field_json_path(self.field, json_path, true)
92    }
93    fn generate_json_term_range(&self) -> RangeInclusive<Term> {
94        let start_term_str = format!("{}\0", self.full_path);
95        let end_term_str = format!("{}{}", self.full_path, JSON_SEGMENT_UPPER_TERMINATOR_STR);
96        self.get_json_term(&start_term_str)..=self.get_json_term(&end_term_str)
97    }
98}
99#[async_trait]
100impl Weight for ExistsWeight {
101    fn scorer(&self, reader: &SegmentReader, boost: Score) -> Result<Box<dyn Scorer>> {
102        let max_doc = reader.max_doc();
103        let mut doc_bitset = BitSet::with_max_value(max_doc);
104
105        let inverted_index = reader.inverted_index(self.field)?;
106        let terms = inverted_index.terms();
107        let mut term_stream = if self.full_path.is_empty() {
108            terms.stream()?
109        } else {
110            let json_term_range = self.generate_json_term_range();
111            terms
112                .range()
113                .ge(json_term_range.start().serialized_value_bytes())
114                .le(json_term_range.end().serialized_value_bytes())
115                .into_stream()?
116        };
117        while term_stream.advance() {
118            let term_info = term_stream.value();
119
120            let mut block_segment_postings = inverted_index.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
121
122            loop {
123                let docs = block_segment_postings.docs();
124
125                if docs.is_empty() {
126                    break;
127                }
128                for &doc in block_segment_postings.docs() {
129                    doc_bitset.insert(doc);
130                }
131                block_segment_postings.advance();
132            }
133        }
134
135        let doc_bitset = BitSetDocSet::from(doc_bitset);
136        Ok(Box::new(ConstScorer::new(doc_bitset, boost)))
137    }
138
139    fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
140        let mut scorer = self.scorer(reader, 1.0)?;
141        if scorer.seek(doc) != doc {
142            return Err(TantivyError::InvalidArgument(format!("Document #({}) does not match", doc)));
143        }
144        Ok(Explanation::new("ExistsQuery", 1.0))
145    }
146
147    async fn scorer_async(&self, reader: &SegmentReader, boost: Score) -> Result<Box<dyn Scorer>> {
148        let max_doc = reader.max_doc();
149        let mut doc_bitset = BitSet::with_max_value(max_doc);
150
151        let inverted_index = reader.inverted_index_async(self.field).await?;
152        let terms = inverted_index.terms();
153        let mut term_stream = if self.full_path.is_empty() {
154            terms.range().into_stream_async().await?
155        } else {
156            let json_term_range = self.generate_json_term_range();
157            terms
158                .range()
159                .ge(json_term_range.start().serialized_value_bytes())
160                .le(json_term_range.end().serialized_value_bytes())
161                .into_stream_async()
162                .await?
163        };
164        while term_stream.advance() {
165            let term_info = term_stream.value();
166
167            let mut block_segment_postings = inverted_index
168                .read_block_postings_from_terminfo_async(term_info, IndexRecordOption::Basic)
169                .await?;
170
171            loop {
172                let docs = block_segment_postings.docs();
173
174                if docs.is_empty() {
175                    break;
176                }
177                for &doc in block_segment_postings.docs() {
178                    doc_bitset.insert(doc);
179                }
180                block_segment_postings.advance();
181            }
182        }
183
184        let doc_bitset = BitSetDocSet::from(doc_bitset);
185        Ok(Box::new(ConstScorer::new(doc_bitset, boost)))
186    }
187}