1use std::io;
2use std::ops::Bound;
3
4use tantivy_fst::automaton::AlwaysMatch;
5use tantivy_fst::Automaton;
6
7use crate::dictionary::Dictionary;
8use crate::{SSTable, TermOrdinal};
9
10pub struct StreamerBuilder<'a, TSSTable, A = AlwaysMatch>
13where
14 A: Automaton,
15 A::State: Clone,
16 TSSTable: SSTable,
17{
18 term_dict: &'a Dictionary<TSSTable>,
19 automaton: A,
20 lower: Bound<Vec<u8>>,
21 upper: Bound<Vec<u8>>,
22}
23
/// Borrows an owned byte-vector bound as the equivalent bound over a byte slice.
///
/// The kind of bound (`Included`, `Excluded`, `Unbounded`) is preserved; only the
/// payload changes from `Vec<u8>` to a `&[u8]` view of the same bytes.
fn bound_as_byte_slice(bound: &Bound<Vec<u8>>) -> Bound<&[u8]> {
    match bound {
        Bound::Unbounded => Bound::Unbounded,
        Bound::Included(bytes) => Bound::Included(&bytes[..]),
        Bound::Excluded(bytes) => Bound::Excluded(&bytes[..]),
    }
}
31
32impl<'a, TSSTable, A> StreamerBuilder<'a, TSSTable, A>
33where
34 A: Automaton,
35 A::State: Clone,
36 TSSTable: SSTable,
37{
38 pub(crate) fn new(term_dict: &'a Dictionary<TSSTable>, automaton: A) -> Self {
39 StreamerBuilder {
40 term_dict,
41 automaton,
42 lower: Bound::Unbounded,
43 upper: Bound::Unbounded,
44 }
45 }
46
47 pub fn ge<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
49 self.lower = Bound::Included(bound.as_ref().to_owned());
50 self
51 }
52
53 pub fn gt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
55 self.lower = Bound::Excluded(bound.as_ref().to_owned());
56 self
57 }
58
59 pub fn le<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
61 self.upper = Bound::Included(bound.as_ref().to_owned());
62 self
63 }
64
65 pub fn lt<T: AsRef<[u8]>>(mut self, bound: T) -> Self {
67 self.upper = Bound::Excluded(bound.as_ref().to_owned());
68 self
69 }
70
71 pub fn into_stream(self) -> io::Result<Streamer<'a, TSSTable, A>> {
74 let start_state = self.automaton.start();
76 let key_range = (
77 bound_as_byte_slice(&self.lower),
78 bound_as_byte_slice(&self.upper),
79 );
80 let delta_reader = self
81 .term_dict
82 .sstable_delta_reader_for_key_range(key_range)?;
83 Ok(Streamer {
84 automaton: self.automaton,
85 states: vec![start_state],
86 delta_reader,
87 key: Vec::new(),
88 term_ord: None,
89 lower_bound: self.lower,
90 upper_bound: self.upper,
91 })
92 }
93}
94
95pub struct Streamer<'a, TSSTable, A = AlwaysMatch>
98where
99 A: Automaton,
100 A::State: Clone,
101 TSSTable: SSTable,
102{
103 automaton: A,
104 states: Vec<A::State>,
105 delta_reader: crate::DeltaReader<'a, TSSTable::ValueReader>,
106 key: Vec<u8>,
107 term_ord: Option<TermOrdinal>,
108 lower_bound: Bound<Vec<u8>>,
109 upper_bound: Bound<Vec<u8>>,
110}
111
112impl<'a, TSSTable, A> Streamer<'a, TSSTable, A>
113where
114 A: Automaton,
115 A::State: Clone,
116 TSSTable: SSTable,
117{
118 pub fn advance(&mut self) -> bool {
122 while self.delta_reader.advance().unwrap() {
123 self.term_ord = Some(
124 self.term_ord
125 .map(|term_ord| term_ord + 1u64)
126 .unwrap_or(0u64),
127 );
128 let common_prefix_len = self.delta_reader.common_prefix_len();
129 self.states.truncate(common_prefix_len + 1);
130 self.key.truncate(common_prefix_len);
131 let mut state: A::State = self.states.last().unwrap().clone();
132 for &b in self.delta_reader.suffix() {
133 state = self.automaton.accept(&state, b);
134 self.states.push(state.clone());
135 }
136 self.key.extend_from_slice(self.delta_reader.suffix());
137 let match_lower_bound = match &self.lower_bound {
138 Bound::Unbounded => true,
139 Bound::Included(lower_bound_key) => lower_bound_key[..] <= self.key[..],
140 Bound::Excluded(lower_bound_key) => lower_bound_key[..] < self.key[..],
141 };
142 if !match_lower_bound {
143 continue;
144 }
145 self.lower_bound = Bound::Unbounded;
147 let match_upper_bound = match &self.upper_bound {
148 Bound::Unbounded => true,
149 Bound::Included(upper_bound_key) => upper_bound_key[..] >= self.key[..],
150 Bound::Excluded(upper_bound_key) => upper_bound_key[..] > self.key[..],
151 };
152 if !match_upper_bound {
153 return false;
154 }
155 if self.automaton.is_match(&state) {
156 return true;
157 }
158 }
159 false
160 }
161
162 pub fn term_ord(&self) -> TermOrdinal {
167 self.term_ord.unwrap_or(0u64)
168 }
169
170 pub fn key(&self) -> &[u8] {
181 &self.key
182 }
183
184 pub fn value(&self) -> &TSSTable::Value {
194 self.delta_reader.value()
195 }
196
197 #[allow(clippy::should_implement_trait)]
199 pub fn next(&mut self) -> Option<(&[u8], &TSSTable::Value)> {
200 if self.advance() {
201 Some((self.key(), self.value()))
202 } else {
203 None
204 }
205 }
206}
207
#[cfg(test)]
mod tests {
    use std::io;

    use common::OwnedBytes;

    use crate::{Dictionary, MonotonicU64SSTable};

    /// Keys of the fixture dictionary in insertion order; each key maps to its index.
    const TERMS: [&[u8]; 4] = [b"abaisance", b"abalation", b"abalienate", b"abandon"];

    /// Builds an in-memory dictionary mapping every entry of `TERMS` to its index.
    fn create_test_dictionary() -> io::Result<Dictionary<MonotonicU64SSTable>> {
        let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new())?;
        for (value, term) in (0u64..).zip(TERMS) {
            builder.insert(term, &value)?;
        }
        let serialized = builder.finish()?;
        Dictionary::from_bytes(OwnedBytes::new(serialized))
    }

    /// An unfiltered stream yields every entry in order, then stops.
    #[test]
    fn test_sstable_stream() -> io::Result<()> {
        let dict = create_test_dictionary()?;
        let mut streamer = dict.stream()?;
        for (value, term) in (0u64..).zip(TERMS) {
            assert!(streamer.advance());
            assert_eq!(streamer.key(), term);
            assert_eq!(streamer.value(), &value);
        }
        assert!(!streamer.advance());
        Ok(())
    }

    /// A regex search yields only the entries whose key matches the pattern.
    #[test]
    fn test_sstable_search() -> io::Result<()> {
        let term_dict = create_test_dictionary()?;
        let pattern = tantivy_fst::Regex::new("ab.*t.*").unwrap();
        let mut term_streamer = term_dict.search(pattern).into_stream()?;
        let expected = [(TERMS[1], 1u64), (TERMS[2], 2u64)];
        for (term, value) in expected {
            assert!(term_streamer.advance());
            assert_eq!(term_streamer.key(), term);
            assert_eq!(term_streamer.value(), &value);
        }
        assert!(!term_streamer.advance());
        Ok(())
    }
}