summavy_sstable/
streamer.rs

1use std::io;
2use std::ops::Bound;
3
4use tantivy_fst::automaton::AlwaysMatch;
5use tantivy_fst::Automaton;
6
7use crate::dictionary::Dictionary;
8use crate::{SSTable, TermOrdinal};
9
10/// `StreamerBuilder` is a helper object used to define
11/// a range of terms that should be streamed.
12pub struct StreamerBuilder<'a, TSSTable, A = AlwaysMatch>
13where
14    A: Automaton,
15    A::State: Clone,
16    TSSTable: SSTable,
17{
18    term_dict: &'a Dictionary<TSSTable>,
19    automaton: A,
20    lower: Bound<Vec<u8>>,
21    upper: Bound<Vec<u8>>,
22}
23
/// Borrows an owned byte-vector bound as a byte-slice bound.
///
/// The bound kind (included / excluded / unbounded) is preserved; only the
/// key is re-borrowed as `&[u8]`.
fn bound_as_byte_slice(bound: &Bound<Vec<u8>>) -> Bound<&[u8]> {
    match bound {
        Bound::Unbounded => Bound::Unbounded,
        Bound::Included(bytes) => Bound::Included(&bytes[..]),
        Bound::Excluded(bytes) => Bound::Excluded(&bytes[..]),
    }
}
31
32impl<'a, TSSTable, A> StreamerBuilder<'a, TSSTable, A>
33where
34    A: Automaton,
35    A::State: Clone,
36    TSSTable: SSTable,
37{
38    pub(crate) fn new(term_dict: &'a Dictionary<TSSTable>, automaton: A) -> Self {
39        StreamerBuilder {
40            term_dict,
41            automaton,
42            lower: Bound::Unbounded,
43            upper: Bound::Unbounded,
44        }
45    }
46
47    /// Limit the range to terms greater or equal to the bound
48    pub fn ge<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
49        self.lower = Bound::Included(bound.as_ref().to_owned());
50        self
51    }
52
53    /// Limit the range to terms strictly greater than the bound
54    pub fn gt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
55        self.lower = Bound::Excluded(bound.as_ref().to_owned());
56        self
57    }
58
59    /// Limit the range to terms lesser or equal to the bound
60    pub fn le<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
61        self.upper = Bound::Included(bound.as_ref().to_owned());
62        self
63    }
64
65    /// Limit the range to terms lesser or equal to the bound
66    pub fn lt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
67        self.upper = Bound::Excluded(bound.as_ref().to_owned());
68        self
69    }
70
71    /// Creates the stream corresponding to the range
72    /// of terms defined using the `StreamerBuilder`.
73    pub fn into_stream(self) -> io::Result<Streamer<'a, TSSTable, A>> {
74        // TODO Optimize by skipping to the right first block.
75        let start_state = self.automaton.start();
76        let key_range = (
77            bound_as_byte_slice(&self.lower),
78            bound_as_byte_slice(&self.upper),
79        );
80        let delta_reader = self
81            .term_dict
82            .sstable_delta_reader_for_key_range(key_range)?;
83        Ok(Streamer {
84            automaton: self.automaton,
85            states: vec![start_state],
86            delta_reader,
87            key: Vec::new(),
88            term_ord: None,
89            lower_bound: self.lower,
90            upper_bound: self.upper,
91        })
92    }
93}
94
95/// `Streamer` acts as a cursor over a range of terms of a segment.
96/// Terms are guaranteed to be sorted.
97pub struct Streamer<'a, TSSTable, A = AlwaysMatch>
98where
99    A: Automaton,
100    A::State: Clone,
101    TSSTable: SSTable,
102{
103    automaton: A,
104    states: Vec<A::State>,
105    delta_reader: crate::DeltaReader<'a, TSSTable::ValueReader>,
106    key: Vec<u8>,
107    term_ord: Option<TermOrdinal>,
108    lower_bound: Bound<Vec<u8>>,
109    upper_bound: Bound<Vec<u8>>,
110}
111
112impl<'a, TSSTable, A> Streamer<'a, TSSTable, A>
113where
114    A: Automaton,
115    A::State: Clone,
116    TSSTable: SSTable,
117{
118    /// Advance position the stream on the next item.
119    /// Before the first call to `.advance()`, the stream
120    /// is an uninitialized state.
121    pub fn advance(&mut self) -> bool {
122        while self.delta_reader.advance().unwrap() {
123            self.term_ord = Some(
124                self.term_ord
125                    .map(|term_ord| term_ord + 1u64)
126                    .unwrap_or(0u64),
127            );
128            let common_prefix_len = self.delta_reader.common_prefix_len();
129            self.states.truncate(common_prefix_len + 1);
130            self.key.truncate(common_prefix_len);
131            let mut state: A::State = self.states.last().unwrap().clone();
132            for &b in self.delta_reader.suffix() {
133                state = self.automaton.accept(&state, b);
134                self.states.push(state.clone());
135            }
136            self.key.extend_from_slice(self.delta_reader.suffix());
137            let match_lower_bound = match &self.lower_bound {
138                Bound::Unbounded => true,
139                Bound::Included(lower_bound_key) => lower_bound_key[..] <= self.key[..],
140                Bound::Excluded(lower_bound_key) => lower_bound_key[..] < self.key[..],
141            };
142            if !match_lower_bound {
143                continue;
144            }
145            // We match the lower key once. All subsequent keys will pass that bar.
146            self.lower_bound = Bound::Unbounded;
147            let match_upper_bound = match &self.upper_bound {
148                Bound::Unbounded => true,
149                Bound::Included(upper_bound_key) => upper_bound_key[..] >= self.key[..],
150                Bound::Excluded(upper_bound_key) => upper_bound_key[..] > self.key[..],
151            };
152            if !match_upper_bound {
153                return false;
154            }
155            if self.automaton.is_match(&state) {
156                return true;
157            }
158        }
159        false
160    }
161
162    /// Returns the `TermOrdinal` of the given term.
163    ///
164    /// May panic if the called as `.advance()` as never
165    /// been called before.
166    pub fn term_ord(&self) -> TermOrdinal {
167        self.term_ord.unwrap_or(0u64)
168    }
169
170    /// Accesses the current key.
171    ///
172    /// `.key()` should return the key that was returned
173    /// by the `.next()` method.
174    ///
175    /// If the end of the stream as been reached, and `.next()`
176    /// has been called and returned `None`, `.key()` remains
177    /// the value of the last key encountered.
178    ///
179    /// Before any call to `.next()`, `.key()` returns an empty array.
180    pub fn key(&self) -> &[u8] {
181        &self.key
182    }
183
184    /// Accesses the current value.
185    ///
186    /// Calling `.value()` after the end of the stream will return the
187    /// last `.value()` encountered.
188    ///
189    /// # Panics
190    ///
191    /// Calling `.value()` before the first call to `.advance()` returns
192    /// `V::default()`.
193    pub fn value(&self) -> &TSSTable::Value {
194        self.delta_reader.value()
195    }
196
197    /// Return the next `(key, value)` pair.
198    #[allow(clippy::should_implement_trait)]
199    pub fn next(&mut self) -> Option<(&[u8], &TSSTable::Value)> {
200        if self.advance() {
201            Some((self.key(), self.value()))
202        } else {
203            None
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    use std::io;

    use common::OwnedBytes;

    use crate::{Dictionary, MonotonicU64SSTable};

    /// Builds a small single-block dictionary with four sorted keys mapped to 0..=3.
    fn create_test_dictionary() -> io::Result<Dictionary<MonotonicU64SSTable>> {
        let mut dict_builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new())?;
        dict_builder.insert(b"abaisance", &0)?;
        dict_builder.insert(b"abalation", &1)?;
        dict_builder.insert(b"abalienate", &2)?;
        dict_builder.insert(b"abandon", &3)?;
        let buffer = dict_builder.finish()?;
        let owned_bytes = OwnedBytes::new(buffer);
        Dictionary::from_bytes(owned_bytes)
    }

    #[test]
    fn test_sstable_stream() -> io::Result<()> {
        let dict = create_test_dictionary()?;
        let mut streamer = dict.stream()?;
        assert!(streamer.advance());
        assert_eq!(streamer.key(), b"abaisance");
        assert_eq!(streamer.value(), &0);
        assert!(streamer.advance());
        assert_eq!(streamer.key(), b"abalation");
        assert_eq!(streamer.value(), &1);
        assert!(streamer.advance());
        assert_eq!(streamer.key(), b"abalienate");
        assert_eq!(streamer.value(), &2);
        assert!(streamer.advance());
        assert_eq!(streamer.key(), b"abandon");
        assert_eq!(streamer.value(), &3);
        assert!(!streamer.advance());
        Ok(())
    }

    #[test]
    fn test_sstable_search() -> io::Result<()> {
        let term_dict = create_test_dictionary()?;
        let ptn = tantivy_fst::Regex::new("ab.*t.*").unwrap();
        let mut term_streamer = term_dict.search(ptn).into_stream()?;
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalation");
        assert_eq!(term_streamer.value(), &1u64);
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalienate");
        assert_eq!(term_streamer.value(), &2u64);
        assert!(!term_streamer.advance());
        Ok(())
    }

    #[test]
    fn test_sstable_search_term_ord_counts_skipped_terms() -> io::Result<()> {
        let term_dict = create_test_dictionary()?;
        let ptn = tantivy_fst::Regex::new("ab.*t.*").unwrap();
        let mut term_streamer = term_dict.search(ptn).into_stream()?;
        // "abaisance" is rejected by the automaton but still counts as ordinal 0.
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalation");
        assert_eq!(term_streamer.term_ord(), 1);
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalienate");
        assert_eq!(term_streamer.term_ord(), 2);
        assert!(!term_streamer.advance());
        Ok(())
    }

    #[test]
    fn test_sstable_search_inclusive_lower_exclusive_upper() -> io::Result<()> {
        let term_dict = create_test_dictionary()?;
        let ptn = tantivy_fst::Regex::new("ab.*").unwrap();
        let mut term_streamer = term_dict
            .search(ptn)
            .ge("abalation")
            .lt("abandon")
            .into_stream()?;
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalation");
        assert_eq!(term_streamer.value(), &1u64);
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalienate");
        assert_eq!(term_streamer.value(), &2u64);
        assert!(!term_streamer.advance());
        Ok(())
    }

    #[test]
    fn test_sstable_search_exclusive_lower_inclusive_upper() -> io::Result<()> {
        let term_dict = create_test_dictionary()?;
        let ptn = tantivy_fst::Regex::new("ab.*").unwrap();
        let mut term_streamer = term_dict
            .search(ptn)
            .gt("abalation")
            .le("abandon")
            .into_stream()?;
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abalienate");
        assert_eq!(term_streamer.value(), &2u64);
        assert!(term_streamer.advance());
        assert_eq!(term_streamer.key(), b"abandon");
        assert_eq!(term_streamer.value(), &3u64);
        assert!(!term_streamer.advance());
        Ok(())
    }
}