// tantivy_sstable/streamer.rs

1use std::io;
2use std::ops::Bound;
3
4use tantivy_fst::automaton::AlwaysMatch;
5use tantivy_fst::Automaton;
6
7use crate::dictionary::Dictionary;
8use crate::{DeltaReader, SSTable, TermOrdinal};
9
10/// `StreamerBuilder` is a helper object used to define
11/// a range of terms that should be streamed.
12pub struct StreamerBuilder<'a, TSSTable, A = AlwaysMatch>
13where
14    A: Automaton,
15    A::State: Clone,
16    TSSTable: SSTable,
17{
18    term_dict: &'a Dictionary<TSSTable>,
19    automaton: A,
20    lower: Bound<Vec<u8>>,
21    upper: Bound<Vec<u8>>,
22    limit: Option<u64>,
23}
24
/// Borrows the key stored in `bound` as a byte slice, keeping the bound kind
/// (`Included` / `Excluded` / `Unbounded`) unchanged.
fn bound_as_byte_slice(bound: &Bound<Vec<u8>>) -> Bound<&[u8]> {
    match bound {
        Bound::Unbounded => Bound::Unbounded,
        Bound::Excluded(bytes) => Bound::Excluded(&bytes[..]),
        Bound::Included(bytes) => Bound::Included(&bytes[..]),
    }
}
32
33impl<'a, TSSTable, A> StreamerBuilder<'a, TSSTable, A>
34where
35    A: Automaton,
36    A::State: Clone,
37    TSSTable: SSTable,
38{
39    pub(crate) fn new(term_dict: &'a Dictionary<TSSTable>, automaton: A) -> Self {
40        StreamerBuilder {
41            term_dict,
42            automaton,
43            lower: Bound::Unbounded,
44            upper: Bound::Unbounded,
45            limit: None,
46        }
47    }
48
49    /// Limit the range to terms greater or equal to the bound
50    pub fn ge<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
51        self.lower = Bound::Included(bound.as_ref().to_owned());
52        self
53    }
54
55    /// Limit the range to terms strictly greater than the bound
56    pub fn gt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
57        self.lower = Bound::Excluded(bound.as_ref().to_owned());
58        self
59    }
60
61    /// Limit the range to terms lesser or equal to the bound
62    pub fn le<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
63        self.upper = Bound::Included(bound.as_ref().to_owned());
64        self
65    }
66
67    /// Limit the range to terms lesser or equal to the bound
68    pub fn lt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
69        self.upper = Bound::Excluded(bound.as_ref().to_owned());
70        self
71    }
72
73    /// Load no more data than what's required to to get `limit`
74    /// matching entries.
75    ///
76    /// The resulting [`Streamer`] can still return marginaly
77    /// more than `limit` elements.
78    pub fn limit(mut self, limit: u64) -> Self {
79        self.limit = Some(limit);
80        self
81    }
82
83    fn delta_reader(&self) -> io::Result<DeltaReader<TSSTable::ValueReader>> {
84        let key_range = (
85            bound_as_byte_slice(&self.lower),
86            bound_as_byte_slice(&self.upper),
87        );
88        self.term_dict
89            .sstable_delta_reader_for_key_range(key_range, self.limit)
90    }
91
92    async fn delta_reader_async(&self) -> io::Result<DeltaReader<TSSTable::ValueReader>> {
93        let key_range = (
94            bound_as_byte_slice(&self.lower),
95            bound_as_byte_slice(&self.upper),
96        );
97        self.term_dict
98            .sstable_delta_reader_for_key_range_async(key_range, self.limit)
99            .await
100    }
101
102    fn into_stream_given_delta_reader(
103        self,
104        delta_reader: DeltaReader<<TSSTable as SSTable>::ValueReader>,
105    ) -> io::Result<Streamer<'a, TSSTable, A>> {
106        let start_state = self.automaton.start();
107        let start_key = bound_as_byte_slice(&self.lower);
108
109        let first_term = match start_key {
110            Bound::Included(key) | Bound::Excluded(key) => self
111                .term_dict
112                .sstable_index
113                .get_block_with_key(key)
114                .map(|block| block.first_ordinal)
115                .unwrap_or(0),
116            Bound::Unbounded => 0,
117        };
118
119        Ok(Streamer {
120            automaton: self.automaton,
121            states: vec![start_state],
122            delta_reader,
123            key: Vec::new(),
124            term_ord: first_term.checked_sub(1),
125            lower_bound: self.lower,
126            upper_bound: self.upper,
127            _lifetime: std::marker::PhantomData,
128        })
129    }
130
131    /// See `into_stream(..)`
132    pub async fn into_stream_async(self) -> io::Result<Streamer<'a, TSSTable, A>> {
133        let delta_reader = self.delta_reader_async().await?;
134        self.into_stream_given_delta_reader(delta_reader)
135    }
136
137    /// Creates the stream corresponding to the range
138    /// of terms defined using the `StreamerBuilder`.
139    pub fn into_stream(self) -> io::Result<Streamer<'a, TSSTable, A>> {
140        let delta_reader = self.delta_reader()?;
141        self.into_stream_given_delta_reader(delta_reader)
142    }
143}
144
145/// `Streamer` acts as a cursor over a range of terms of a segment.
146/// Terms are guaranteed to be sorted.
147pub struct Streamer<'a, TSSTable, A = AlwaysMatch>
148where
149    A: Automaton,
150    A::State: Clone,
151    TSSTable: SSTable,
152{
153    automaton: A,
154    states: Vec<A::State>,
155    delta_reader: crate::DeltaReader<TSSTable::ValueReader>,
156    key: Vec<u8>,
157    term_ord: Option<TermOrdinal>,
158    lower_bound: Bound<Vec<u8>>,
159    upper_bound: Bound<Vec<u8>>,
160    // this field is used to please the type-interface of a dictionary in tantivy
161    _lifetime: std::marker::PhantomData<&'a ()>,
162}
163
164impl<'a, TSSTable> Streamer<'a, TSSTable, AlwaysMatch>
165where TSSTable: SSTable
166{
167    pub fn empty() -> Self {
168        Streamer {
169            automaton: AlwaysMatch,
170            states: Vec::new(),
171            delta_reader: DeltaReader::empty(),
172            key: Vec::new(),
173            term_ord: None,
174            lower_bound: Bound::Unbounded,
175            upper_bound: Bound::Unbounded,
176            _lifetime: std::marker::PhantomData,
177        }
178    }
179}
180
181impl<'a, TSSTable, A> Streamer<'a, TSSTable, A>
182where
183    A: Automaton,
184    A::State: Clone,
185    TSSTable: SSTable,
186{
187    /// Advance position the stream on the next item.
188    /// Before the first call to `.advance()`, the stream
189    /// is an uninitialized state.
190    pub fn advance(&mut self) -> bool {
191        while self.delta_reader.advance().unwrap() {
192            self.term_ord = Some(
193                self.term_ord
194                    .map(|term_ord| term_ord + 1u64)
195                    .unwrap_or(0u64),
196            );
197            let common_prefix_len = self.delta_reader.common_prefix_len();
198            self.states.truncate(common_prefix_len + 1);
199            self.key.truncate(common_prefix_len);
200            let mut state: A::State = self.states.last().unwrap().clone();
201            for &b in self.delta_reader.suffix() {
202                state = self.automaton.accept(&state, b);
203                self.states.push(state.clone());
204            }
205            self.key.extend_from_slice(self.delta_reader.suffix());
206            let match_lower_bound = match &self.lower_bound {
207                Bound::Unbounded => true,
208                Bound::Included(lower_bound_key) => lower_bound_key[..] <= self.key[..],
209                Bound::Excluded(lower_bound_key) => lower_bound_key[..] < self.key[..],
210            };
211            if !match_lower_bound {
212                continue;
213            }
214            // We match the lower key once. All subsequent keys will pass that bar.
215            self.lower_bound = Bound::Unbounded;
216            let match_upper_bound = match &self.upper_bound {
217                Bound::Unbounded => true,
218                Bound::Included(upper_bound_key) => upper_bound_key[..] >= self.key[..],
219                Bound::Excluded(upper_bound_key) => upper_bound_key[..] > self.key[..],
220            };
221            if !match_upper_bound {
222                return false;
223            }
224            if self.automaton.is_match(&state) {
225                return true;
226            }
227        }
228        false
229    }
230
231    /// Returns the `TermOrdinal` of the given term.
232    ///
233    /// May panic if the called as `.advance()` as never
234    /// been called before.
235    pub fn term_ord(&self) -> TermOrdinal {
236        self.term_ord.unwrap_or(0u64)
237    }
238
239    /// Accesses the current key.
240    ///
241    /// `.key()` should return the key that was returned
242    /// by the `.next()` method.
243    ///
244    /// If the end of the stream as been reached, and `.next()`
245    /// has been called and returned `None`, `.key()` remains
246    /// the value of the last key encountered.
247    ///
248    /// Before any call to `.next()`, `.key()` returns an empty array.
249    pub fn key(&self) -> &[u8] {
250        &self.key
251    }
252
253    /// Accesses the current value.
254    ///
255    /// Calling `.value()` after the end of the stream will return the
256    /// last `.value()` encountered.
257    ///
258    /// # Panics
259    ///
260    /// Calling `.value()` before the first call to `.advance()` returns
261    /// `V::default()`.
262    pub fn value(&self) -> &TSSTable::Value {
263        self.delta_reader.value()
264    }
265
266    /// Return the next `(key, value)` pair.
267    #[allow(clippy::should_implement_trait)]
268    pub fn next(&mut self) -> Option<(&[u8], &TSSTable::Value)> {
269        if self.advance() {
270            Some((self.key(), self.value()))
271        } else {
272            None
273        }
274    }
275}
276
#[cfg(test)]
mod tests {
    use std::io;

    use common::OwnedBytes;

    use crate::{Dictionary, MonotonicU64SSTable};

    /// Builds a small in-memory dictionary holding four sorted keys, each mapped
    /// to its position `0..4`.
    fn create_test_dictionary() -> io::Result<Dictionary<MonotonicU64SSTable>> {
        let mut dict_builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new())?;
        dict_builder.insert(b"abaisance", &0)?;
        dict_builder.insert(b"abalation", &1)?;
        dict_builder.insert(b"abalienate", &2)?;
        dict_builder.insert(b"abandon", &3)?;
        let buffer = dict_builder.finish()?;
        let owned_bytes = OwnedBytes::new(buffer);
        Dictionary::from_bytes(owned_bytes)
    }

    #[test]
    fn test_sstable_stream() -> io::Result<()> {
        let dict = create_test_dictionary()?;
        let mut streamer = dict.stream()?;
        assert!(streamer.advance());
        assert_eq!(streamer.key(), b"abaisance");
        assert_eq!(streamer.value(), &0);
        assert!(streamer.advance());
        assert_eq!(streamer.key(), b"abalation");
        assert_eq!(streamer.value(), &1);
        assert!(streamer.advance());
        assert_eq!(streamer.key(), b"abalienate");
        assert_eq!(streamer.value(), &2);
        assert!(streamer.advance());
        assert_eq!(streamer.key(), b"abandon");
        assert_eq!(streamer.value(), &3);
        assert!(!streamer.advance());
        Ok(())
    }

    #[test]
    fn test_sstable_search() -> io::Result<()> {
        let term_dict = create_test_dictionary()?;
        let ptn = tantivy_fst::Regex::new("ab.*t.*").unwrap();
        let mut term_streamer = term_dict.search(ptn).into_stream()?;
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalation");
        assert_eq!(term_streamer.value(), &1u64);
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalienate");
        assert_eq!(term_streamer.value(), &2u64);
        assert!(!term_streamer.advance());
        Ok(())
    }

    // Exclusive bounds on both sides: the bound keys themselves are skipped,
    // and ordinals still count the skipped entries.
    #[test]
    fn test_sstable_exclusive_range() -> io::Result<()> {
        let term_dict = create_test_dictionary()?;
        let ptn = tantivy_fst::Regex::new("ab.*").unwrap();
        let mut term_streamer = term_dict
            .search(ptn)
            .gt("abaisance")
            .lt("abandon")
            .into_stream()?;
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalation");
        assert_eq!(term_streamer.value(), &1u64);
        assert_eq!(term_streamer.term_ord(), 1);
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalienate");
        assert_eq!(term_streamer.value(), &2u64);
        assert_eq!(term_streamer.term_ord(), 2);
        assert!(!term_streamer.advance());
        Ok(())
    }

    // Inclusive bounds on both sides: the bound keys themselves are yielded.
    #[test]
    fn test_sstable_inclusive_range() -> io::Result<()> {
        let term_dict = create_test_dictionary()?;
        let ptn = tantivy_fst::Regex::new("ab.*").unwrap();
        let mut term_streamer = term_dict
            .search(ptn)
            .ge("abalienate")
            .le("abandon")
            .into_stream()?;
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalienate");
        assert_eq!(term_streamer.value(), &2u64);
        assert_eq!(term_streamer.term_ord(), 2);
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abandon");
        assert_eq!(term_streamer.value(), &3u64);
        assert_eq!(term_streamer.term_ord(), 3);
        assert!(!term_streamer.advance());
        Ok(())
    }
}