trie_rs/
inc_search.rs

//! Incremental search
//!
//! # Motivation
//!
//! This struct is meant for "online" or interactive use cases, where input is
//! accumulated step by step and matched against a trie as it grows. The
//! standard [`exact_match()`][crate::trie::Trie::exact_match] method has a
//! time complexity of _O(m log n)_, where _m_ is the query string length and
//! _n_ is the number of entries in the trie. Consider this loop, which
//! simulates accumulating a query and re-running the match after every step.
//!
//! ```rust
//! use trie_rs::Trie;
//!
//! let q = "appli"; // query string
//! let trie = Trie::from_iter(vec!["appli", "application"]);
//! for i in 0..q.len() {
//!     assert!(!trie.exact_match(&q[0..i]));
//! }
//! assert!(trie.exact_match(q));
//! ```
//!
//! Building the query one "character" at a time and calling `exact_match()`
//! after every step gives the loop an effective complexity of
//! _O(m<sup>2</sup> log n)_.
//!
//! With an incremental search, each step continues from where the previous
//! one stopped, so each query costs _O(log n)_ and returns an [`Answer`] (or
//! `None` when the next label does not match).
//!
//! ```ignore
//! let q = "appli"; // query string
//! let mut inc_search = trie.inc_search();
//! let mut is_match = false;
//! for chr in q.as_bytes() {
//!     is_match = inc_search.query(chr).unwrap().is_match();
//! }
//! ```
//!
//! The loop above therefore runs in _O(m log n)_ overall, bringing the whole
//! accumulation back to the cost of a single `exact_match()` call.
41use crate::{
42    map::Trie,
43    try_collect::{TryCollect, TryFromIterator},
44};
45use louds_rs::LoudsNodeNum;
46
47#[derive(Debug, Clone)]
48/// An incremental search of the trie.
49pub struct IncSearch<'a, Label, Value> {
50    trie: &'a Trie<Label, Value>,
51    node: LoudsNodeNum,
52}
53
54/// Search position in the trie.
55///
56/// # Why do this?
57///
58/// "Position" is more descriptive for incremental search purposes, and without
59/// it a user would have to explicitly depend on `louds-rs`.
60pub type Position = LoudsNodeNum;
61
62/// Retrieve the position the search is on. Useful for hanging on to a search
63/// without having to fight the borrow checker because its borrowing a trie.
64impl<'a, L, V> From<IncSearch<'a, L, V>> for Position {
65    fn from(inc_search: IncSearch<'a, L, V>) -> Self {
66        inc_search.node
67    }
68}
69
/// Outcome of one successful step of an incremental search.
///
/// A step that finds no child with the requested label is reported as `None`
/// by the search methods rather than as a variant of this enum.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Answer {
    /// Longer keys continue past this node, but the node itself is not a key.
    Prefix,
    /// This node is a key and no longer key continues past it.
    Match,
    /// This node is a key, and longer keys also continue past it.
    PrefixAndMatch,
}

impl Answer {
    /// Returns `true` for [`Answer::Prefix`] and [`Answer::PrefixAndMatch`].
    pub fn is_prefix(&self) -> bool {
        match self {
            Answer::Prefix | Answer::PrefixAndMatch => true,
            Answer::Match => false,
        }
    }

    /// Returns `true` for [`Answer::Match`] and [`Answer::PrefixAndMatch`].
    pub fn is_match(&self) -> bool {
        match self {
            Answer::Match | Answer::PrefixAndMatch => true,
            Answer::Prefix => false,
        }
    }

    /// Combines the two node flags into an answer.
    ///
    /// Returns `None` when the node is neither a prefix nor a match.
    fn new(is_prefix: bool, is_match: bool) -> Option<Self> {
        if is_prefix && is_match {
            Some(Answer::PrefixAndMatch)
        } else if is_prefix {
            Some(Answer::Prefix)
        } else if is_match {
            Some(Answer::Match)
        } else {
            None
        }
    }
}
101
102impl<'a, Label: Ord, Value> IncSearch<'a, Label, Value> {
103    /// Create a new incremental search for a trie.
104    pub fn new(trie: &'a Trie<Label, Value>) -> Self {
105        Self {
106            trie,
107            node: LoudsNodeNum(1),
108        }
109    }
110
111    /// Resume an incremental search at a particular point.
112    ///
113    /// ```
114    /// use trie_rs::{Trie, inc_search::{Answer, IncSearch}};
115    /// use louds_rs::LoudsNodeNum;
116    ///
117    /// let trie: Trie<u8> = ["hello", "bye"].into_iter().collect();
118    /// let mut inc_search = trie.inc_search();
119    ///
120    /// assert_eq!(inc_search.query_until("he"), Ok(Answer::Prefix));
121    /// let position = LoudsNodeNum::from(inc_search);
122    ///
123    /// // inc_search is dropped.
124    /// let mut inc_search2 = IncSearch::resume(&trie.0, position);
125    /// assert_eq!(inc_search2.query_until("llo"), Ok(Answer::Match));
126    ///
127    /// ```
128    pub fn resume(trie: &'a Trie<Label, Value>, position: Position) -> Self {
129        Self {
130            trie,
131            node: position,
132        }
133    }
134
135    /// Query but do not change the node we're looking at on the trie.
136    pub fn peek(&self, chr: &Label) -> Option<Answer> {
137        let children_node_nums: Vec<_> = self.trie.children_node_nums(self.node).collect();
138        let res = self
139            .trie
140            .bin_search_by_children_labels(chr, &children_node_nums[..]);
141        match res {
142            Ok(j) => {
143                let node = children_node_nums[j];
144                let is_prefix = self.trie.has_children_node_nums(node);
145                let is_match = self.trie.value(node).is_some();
146                Answer::new(is_prefix, is_match)
147            }
148            Err(_) => None,
149        }
150    }
151
152    /// Query the trie and go to node if there is a match.
153    pub fn query(&mut self, chr: &Label) -> Option<Answer> {
154        let children_node_nums: Vec<_> = self.trie.children_node_nums(self.node).collect();
155        let res = self
156            .trie
157            .bin_search_by_children_labels(chr, &children_node_nums[..]);
158        match res {
159            Ok(j) => {
160                self.node = children_node_nums[j];
161                let is_prefix = self.trie.has_children_node_nums(self.node);
162                let is_match = self.trie.value(self.node).is_some();
163                Answer::new(is_prefix, is_match)
164            }
165            Err(_) => None,
166        }
167    }
168
169    /// Query the trie with a sequence. Will return `Err(index of query)` on
170    /// first failure to match.
171    pub fn query_until(&mut self, query: impl AsRef<[Label]>) -> Result<Answer, usize> {
172        let mut result = None;
173        let mut i = 0;
174        for chr in query.as_ref().iter() {
175            result = self.query(chr);
176            if result.is_none() {
177                return Err(i);
178            }
179            i += 1;
180        }
181        result.ok_or(i)
182    }
183
184    /// Return the value at current node. There should be one for any node where
185    /// `answer.is_match()` is true.
186    pub fn value(&self) -> Option<&'a Value> {
187        self.trie.value(self.node)
188    }
189
190    /// Go to the longest shared prefix.
191    pub fn goto_longest_prefix(&mut self) -> Result<usize, usize> {
192        let mut count = 0;
193
194        while count == 0 || !self.trie.is_terminal(self.node) {
195            let mut iter = self.trie.children_node_nums(self.node);
196            let first = iter.next();
197            let second = iter.next();
198            match (first, second) {
199                (Some(child_node_num), None) => {
200                    self.node = child_node_num;
201                    count += 1;
202                }
203                (None, _) => {
204                    assert_eq!(count, 0);
205                    return Ok(count);
206                }
207                _ => {
208                    return Err(count);
209                }
210            }
211        }
212        Ok(count)
213    }
214
215    /// Return the current prefix for this search.
216    pub fn prefix<C, M>(&self) -> C
217    where
218        C: TryFromIterator<Label, M>,
219        Label: Clone,
220    {
221        let mut v: Vec<Label> = self
222            .trie
223            .child_to_ancestors(self.node)
224            .map(|node| self.trie.label(node).clone())
225            .collect();
226        v.reverse();
227        v.into_iter().try_collect().expect("Could not collect")
228    }
229
230    /// Returne the length of the current prefix for this search.
231    pub fn prefix_len(&self) -> usize {
232        // TODO: If PR for child_to_ancestors is accepted. Use the iterator and
233        // remove `pub(crate)` from Trie.louds field. Also uncomment prefix()
234        // above.
235
236        self.trie.child_to_ancestors(self.node).count()
237
238        // let mut node = self.node;
239        // let mut count = 0;
240        // while node.0 > 1 {
241        //     let index = self.trie.louds.node_num_to_index(node);
242        //     node = self.trie.louds.child_to_parent(index);
243        //     count += 1;
244        // }
245        // count
246    }
247
248    // This isn't actually possible.
249    // /// Return the mutable value at current node. There should be one for any
250    // /// node where `answer.is_match()` is true.
251    // ///
252    // /// Note: Because [IncSearch] does not store a mutable reference to the
253    // /// trie, a mutable reference must be provided.
254    // pub fn value_mut<'b>(self, trie: &'b mut Trie<Label, Value>) -> Option<&'b mut Value> {
255    //     trie.value_mut(self.node)
256    // }
257
258    /// Reset the query.
259    pub fn reset(&mut self) {
260        self.node = LoudsNodeNum(1);
261    }
262}
263
#[cfg(test)]
mod search_tests {
    use super::*;
    use crate::map::{Trie, TrieBuilder};

    /// Shared fixture: six keys mapped to the values 0..=5, including a
    /// multi-byte UTF-8 key to exercise non-ASCII labels.
    fn build_trie() -> Trie<u8, u8> {
        let mut builder = TrieBuilder::new();
        builder.push("a", 0);
        builder.push("app", 1);
        builder.push("apple", 2);
        builder.push("better", 3);
        builder.push("application", 4);
        builder.push("アップル🍎", 5);
        builder.build()
    }

    /// Typing "apple" one byte at a time from the root:
    /// (label, expected answer, prefix afterwards, value afterwards).
    const APPLE_STEPS: [(u8, Answer, &str, Option<u8>); 5] = [
        (b'a', Answer::PrefixAndMatch, "a", Some(0)),
        (b'p', Answer::Prefix, "ap", None),
        (b'p', Answer::PrefixAndMatch, "app", Some(1)),
        (b'l', Answer::Prefix, "appl", None),
        (b'e', Answer::Match, "apple", Some(2)),
    ];

    #[test]
    fn inc_search() {
        let trie = build_trie();
        let mut search = trie.inc_search();
        assert_eq!("", search.prefix::<String, _>());
        assert_eq!(0, search.prefix_len());

        // A miss leaves the search where it was.
        assert_eq!(None, search.query(&b'z'));
        assert_eq!("", search.prefix::<String, _>());
        assert_eq!(0, search.prefix_len());

        for &(chr, answer, prefix, _) in APPLE_STEPS.iter() {
            assert_eq!(Some(answer), search.query(&chr));
            assert_eq!(prefix, search.prefix::<String, _>());
            assert_eq!(prefix.len(), search.prefix_len());
        }
    }

    #[test]
    fn inc_search_value() {
        let trie = build_trie();
        let mut search = trie.inc_search();
        assert_eq!("", search.prefix::<String, _>());
        assert_eq!(None, search.query(&b'z'));
        assert_eq!("", search.prefix::<String, _>());

        for &(chr, answer, prefix, value) in APPLE_STEPS.iter() {
            assert_eq!(Some(answer), search.query(&chr));
            assert_eq!(prefix, search.prefix::<String, _>());
            assert_eq!(value.as_ref(), search.value());
        }
        assert_eq!(Some(&2), search.value());
    }

    #[test]
    fn inc_search_peek_does_not_move() {
        let trie = build_trie();
        let mut search = trie.inc_search();
        assert_eq!(None, search.peek(&b'z'));
        assert_eq!(Some(Answer::PrefixAndMatch), search.peek(&b'a'));
        assert_eq!("", search.prefix::<String, _>());
        assert_eq!(0, search.prefix_len());

        assert_eq!(Some(Answer::PrefixAndMatch), search.query(&b'a'));
        assert_eq!(Some(Answer::Prefix), search.peek(&b'p'));
        assert_eq!("a", search.prefix::<String, _>());
        assert_eq!(1, search.prefix_len());
    }

    #[test]
    fn inc_search_query_until() {
        let trie = build_trie();
        let mut search = trie.inc_search();
        assert_eq!(Err(0), search.query_until(""));
        assert_eq!("", search.prefix::<String, _>());
        assert_eq!(Err(0), search.query_until("zoo"));
        assert_eq!("", search.prefix::<String, _>());
        search.reset();
        assert_eq!(Err(1), search.query_until("blue"));
        assert_eq!("b", search.prefix::<String, _>());
        search.reset();
        assert_eq!(Answer::Match, search.query_until("apple").unwrap());
        assert_eq!("apple", search.prefix::<String, _>());
        assert_eq!(Some(&2), search.value());
        search.reset();
        assert_eq!(Ok(Answer::Match), search.query_until("アップル🍎"));
        assert_eq!(Some(&5), search.value());
    }

    #[test]
    fn inc_search_resume() {
        let trie = build_trie();
        let mut search = trie.inc_search();
        assert_eq!(Ok(Answer::PrefixAndMatch), search.query_until("app"));
        let position = Position::from(search);

        let mut resumed = IncSearch::resume(&trie, position);
        assert_eq!("app", resumed.prefix::<String, _>());
        assert_eq!(Ok(Answer::Match), resumed.query_until("le"));
        assert_eq!(Some(&2), resumed.value());
    }

    #[test]
    fn inc_search_goto_longest_prefix() {
        let trie = build_trie();
        let mut search = trie.inc_search();
        assert_eq!(Err(0), search.goto_longest_prefix());
        assert_eq!("", search.prefix::<String, _>());
        search.reset();
        assert_eq!(Ok(Answer::PrefixAndMatch), search.query_until("a"));
        assert_eq!("a", search.prefix::<String, _>());
        assert_eq!(Ok(2), search.goto_longest_prefix());
        assert_eq!("app", search.prefix::<String, _>());
        assert_eq!(Err(1), search.goto_longest_prefix());
        assert_eq!("appl", search.prefix::<String, _>());
        assert_eq!(Err(0), search.goto_longest_prefix());
        assert_eq!(Ok(Answer::Prefix), search.query_until("i"));
        assert_eq!(Ok(6), search.goto_longest_prefix());
        assert_eq!(Ok(0), search.goto_longest_prefix());
        assert_eq!("application", search.prefix::<String, _>());
        search.reset();
        assert_eq!(Answer::Match, search.query_until("apple").unwrap());
        assert_eq!("apple", search.prefix::<String, _>());
        assert_eq!(Some(&2), search.value());
    }
}