simsearch/
lib.rs

1//! A simple and lightweight fuzzy search engine that works in memory, searching for
2//! similar strings (a pun here).
3//!
4//! # Examples
5//!
6//! ```
7//! use simsearch::SimSearch;
8//!
9//! let mut engine: SimSearch<u32> = SimSearch::new();
10//!
11//! engine.insert(1, "Things Fall Apart");
12//! engine.insert(2, "The Old Man and the Sea");
13//! engine.insert(3, "James Joyce");
14//!
15//! let results: Vec<u32> = engine.search("thngs");
16//!
17//! assert_eq!(results, &[1]);
18//! ```
19//!
20//! By default, Jaro-Winkler distance is used. An alternative Levenshtein distance, which is
21//! SIMD-accelerated but only works for ASCII byte strings, can be specified with `SearchOptions`:
22//!
23//! ```
24//! use simsearch::{SimSearch, SearchOptions};
25//!
26//! let options = SearchOptions::new().levenshtein(true);
27//! let mut engine: SimSearch<u32> = SimSearch::new_with(options);
28//! ```
29
30use std::cmp::{max, Ordering};
31use std::collections::HashMap;
32use std::f64;
33use std::hash::Hash;
34
35use strsim::jaro_winkler;
36use triple_accel::levenshtein::levenshtein_simd_k;
37
38#[cfg(feature = "serde")]
39use serde::{Deserialize, Serialize};
40
41/// The simple search engine.
42#[derive(Debug, Clone)]
43#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
44pub struct SimSearch<Id>
45where
46    Id: Eq + PartialEq + Clone + Hash + Ord,
47{
48    option: SearchOptions,
49    id_num_counter: usize,
50    ids_map: HashMap<Id, usize>,
51    reverse_ids_map: HashMap<usize, Id>,
52    forward_map: HashMap<usize, Vec<String>>,
53    reverse_map: HashMap<String, Vec<usize>>,
54}
55
56impl<Id> SimSearch<Id>
57where
58    Id: Eq + PartialEq + Clone + Hash + Ord,
59{
60    /// Creates search engine with default options.
61    pub fn new() -> Self {
62        Self::new_with(SearchOptions::new())
63    }
64
65    /// Creates search engine with custom options.
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use simsearch::{SearchOptions, SimSearch};
71    ///
72    /// let mut engine: SimSearch<usize> = SimSearch::new_with(
73    ///     SearchOptions::new().case_sensitive(true));
74    /// ```
75    pub fn new_with(option: SearchOptions) -> Self {
76        SimSearch {
77            option,
78            id_num_counter: 0,
79            ids_map: HashMap::new(),
80            reverse_ids_map: HashMap::new(),
81            forward_map: HashMap::new(),
82            reverse_map: HashMap::new(),
83        }
84    }
85
86    /// Inserts an entry into search engine.
87    ///
88    /// Input will be tokenized according to the search option.
89    /// By default whitespaces(including tabs) are considered as stop words,
90    /// you can change the behavior by providing `SearchOptions`.
91    ///
92    /// Insert with an existing id updates the content.
93    ///
94    /// **Note that** id is not searchable. Add id to the contents if you would
95    /// like to perform search on it.
96    ///
97    /// Additionally, note that content must be an ASCII string if Levenshtein
98    /// distance is used.
99    ///
100    /// # Examples
101    ///
102    /// ```
103    /// use simsearch::{SearchOptions, SimSearch};
104    ///
105    /// let mut engine: SimSearch<&str> = SimSearch::new_with(
106    ///     SearchOptions::new().stop_words(vec![",".to_string(), ".".to_string()]));
107    ///
108    /// engine.insert("BoJack Horseman", "BoJack Horseman, an American
109    /// adult animated comedy-drama series created by Raphael Bob-Waksberg.
110    /// The series stars Will Arnett as the title character,
111    /// with a supporting cast including Amy Sedaris,
112    /// Alison Brie, Paul F. Tompkins, and Aaron Paul.");
113    /// ```
114    pub fn insert(&mut self, id: Id, content: &str) {
115        self.insert_tokens(id, &[content])
116    }
117
118    /// Inserts entry tokens into search engine.
119    ///
120    /// Search engine also applies tokenizer to the
121    /// provided tokens. Use this method when you have
122    /// special tokenization rules in addition to the built-in ones.
123    ///
124    /// Insert with an existing id updates the content.
125    ///
126    /// **Note that** id is not searchable. Add id to the contents if you would
127    /// like to perform search on it.
128    ///
129    /// Additionally, note that each token must be an ASCII string if Levenshtein
130    /// distance is used.
131    ///
132    /// # Examples
133    ///
134    /// ```
135    /// use simsearch::SimSearch;
136    ///
137    /// let mut engine: SimSearch<&str> = SimSearch::new();
138    ///
139    /// engine.insert_tokens("Arya Stark", &["Arya Stark", "a fictional
140    /// character in American author George R. R", "portrayed by English actress."]);
141    pub fn insert_tokens(&mut self, id: Id, tokens: &[&str]) {
142        self.remove(&id);
143
144        let id_num = self.id_num_counter;
145        self.ids_map.insert(id.clone(), id_num);
146        self.reverse_ids_map.insert(id_num, id);
147        self.id_num_counter += 1;
148
149        let mut tokens = self.tokenize(tokens);
150        tokens.sort();
151
152        for token in tokens.clone() {
153            self.reverse_map
154                .entry(token)
155                .or_insert_with(|| Vec::with_capacity(1))
156                .push(id_num);
157        }
158
159        self.forward_map.insert(id_num, tokens);
160    }
161
162    /// Searches pattern and returns ids sorted by relevance.
163    ///
164    /// Pattern will be tokenized according to the search option.
165    /// By default whitespaces(including tabs) are considered as stop words,
166    /// you can change the behavior by providing `SearchOptions`.
167    ///
168    /// Additionally, note that pattern must be an ASCII string if Levenshtein
169    /// distance is used.
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use simsearch::SimSearch;
175    ///
176    /// let mut engine: SimSearch<u32> = SimSearch::new();
177    ///
178    /// engine.insert(1, "Things Fall Apart");
179    /// engine.insert(2, "The Old Man and the Sea");
180    /// engine.insert(3, "James Joyce");
181    ///
182    /// let results: Vec<u32> = engine.search("thngs apa");
183    ///
184    /// assert_eq!(results, &[1]);
185    pub fn search(&self, pattern: &str) -> Vec<Id> {
186        self.search_tokens(&[pattern])
187    }
188
189    /// Searches pattern tokens and returns ids sorted by relevance.
190    ///
191    /// Search engine also applies tokenizer to the
192    /// provided tokens. Use this method when you have
193    /// special tokenization rules in addition to the built-in ones.
194    ///
195    /// Additionally, note that each pattern token must be an ASCII
196    /// string if Levenshtein distance is used.
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// use simsearch::SimSearch;
202    ///
203    /// let mut engine: SimSearch<u32> = SimSearch::new();
204    ///
205    /// engine.insert(1, "Things Fall Apart");
206    /// engine.insert(2, "The Old Man and the Sea");
207    /// engine.insert(3, "James Joyce");
208    ///
209    /// let results: Vec<u32> = engine.search_tokens(&["thngs", "apa"]);
210    ///
211    /// assert_eq!(results, &[1]);
212    /// ```
213    pub fn search_tokens(&self, pattern_tokens: &[&str]) -> Vec<Id> {
214        let mut pattern_tokens = self.tokenize(pattern_tokens);
215        pattern_tokens.sort();
216
217        let mut token_scores: HashMap<&str, f64> = HashMap::new();
218
219        for pattern_token in pattern_tokens {
220            for token in self.reverse_map.keys() {
221                let score = if self.option.levenshtein {
222                    let len = max(token.len(), pattern_token.len()) as f64;
223                    // calculate k (based on the threshold) to bound the Levenshtein distance
224                    let k = ((1.0 - self.option.threshold) * len).ceil() as u32;
225                    // levenshtein_simd_k only works on ASCII byte slices, so the token strings
226                    // are directly treated as byte slices
227                    match levenshtein_simd_k(token.as_bytes(), pattern_token.as_bytes(), k) {
228                        Some(dist) => 1.0 - if len == 0.0 { 0.0 } else { (dist as f64) / len },
229                        None => f64::MIN,
230                    }
231                } else {
232                    jaro_winkler(token, &pattern_token)
233                };
234
235                if score > self.option.threshold {
236                    token_scores.insert(token, score);
237                }
238            }
239        }
240
241        let mut result_scores: HashMap<usize, f64> = HashMap::new();
242
243        for (token, score) in token_scores.drain() {
244            for id_num in &self.reverse_map[token] {
245                *result_scores.entry(*id_num).or_insert(0.) += score;
246            }
247        }
248
249        let mut result_scores: Vec<(f64, Id)> = result_scores
250            .drain()
251            .map(|(id_num, score)| {
252                let id = self
253                    .reverse_ids_map
254                    .get(&id_num)
255                    // this can go wrong only if something (e.g. delete) leaves us in an
256                    // inconsistent state
257                    .expect("id at id_num should be there")
258                    .to_owned();
259                (score, id)
260            })
261            .collect();
262
263        result_scores.sort_by(|(lhs_score, lhs_id), (rhs_score, rhs_id)| {
264            match rhs_score.partial_cmp(lhs_score).unwrap() {
265                Ordering::Equal => lhs_id.cmp(rhs_id),
266                ord => ord,
267            }
268        });
269
270        let result_ids: Vec<Id> = result_scores.into_iter().map(|(_, id)| id).collect();
271
272        result_ids
273    }
274
275    /// Remove an entry by id.
276    pub fn remove(&mut self, id: &Id) {
277        if let Some(id_num) = self.ids_map.get(id) {
278            for token in &self.forward_map[id_num] {
279                self.reverse_map
280                    .get_mut(token)
281                    .unwrap()
282                    .retain(|i| i != id_num);
283            }
284            self.forward_map.remove(id_num);
285            self.reverse_ids_map.remove(id_num);
286            self.ids_map.remove(id);
287        };
288    }
289
290    /// Clear all entries.
291    pub fn clear(&mut self) {
292        self.id_num_counter = 0;
293        self.ids_map.clear();
294        self.reverse_ids_map.clear();
295        self.forward_map.clear();
296        self.reverse_map.clear();
297    }
298
299    fn tokenize(&self, tokens: &[&str]) -> Vec<String> {
300        let tokens: Vec<String> = tokens
301            .iter()
302            .map(|token| {
303                if self.option.case_sensitive {
304                    token.to_string()
305                } else {
306                    token.to_lowercase()
307                }
308            })
309            .collect();
310
311        let mut tokens: Vec<String> = if self.option.stop_whitespace {
312            tokens
313                .iter()
314                .flat_map(|token| token.split_whitespace())
315                .map(|token| token.to_string())
316                .collect()
317        } else {
318            tokens
319        };
320
321        for stop_word in &self.option.stop_words {
322            tokens = tokens
323                .iter()
324                .flat_map(|token| token.split_terminator(stop_word.as_str()))
325                .map(|token| token.to_string())
326                .collect();
327        }
328
329        tokens.retain(|token| !token.is_empty());
330
331        tokens
332    }
333}
334
/// Options and flags that configure the search engine.
///
/// Create one with [`SearchOptions::new`] and adjust it with the builder-style
/// setters before handing it to `SimSearch::new_with`.
///
/// # Examples
///
/// ```
/// use simsearch::{SearchOptions, SimSearch};
///
/// let mut engine: SimSearch<usize> = SimSearch::new_with(
///     SearchOptions::new().case_sensitive(true));
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SearchOptions {
    /// Whether text keeps its case; when `false` it is lowercased before tokenizing.
    case_sensitive: bool,
    /// Whether tokens are split on whitespace.
    stop_whitespace: bool,
    /// Extra separators that tokens are split on.
    stop_words: Vec<String>,
    /// Per-token similarity that a match must strictly exceed.
    threshold: f64,
    /// Whether to score with Levenshtein instead of Jaro-Winkler.
    levenshtein: bool,
}
354
355impl SearchOptions {
356    /// Creates a default configuration.
357    pub fn new() -> Self {
358        SearchOptions {
359            case_sensitive: false,
360            stop_whitespace: true,
361            stop_words: vec![],
362            threshold: 0.8,
363            levenshtein: false,
364        }
365    }
366
367    /// Sets whether search engine is case sensitive or not.
368    ///
369    /// Defaults to `false`.
370    pub fn case_sensitive(self, case_sensitive: bool) -> Self {
371        SearchOptions {
372            case_sensitive,
373            ..self
374        }
375    }
376
377    /// Sets the whether search engine splits tokens on whitespace or not.
378    /// The **whitespace** here includes tab, returns and so forth.
379    ///
380    /// See also [`std::str::split_whitespace()`](https://doc.rust-lang.org/std/primitive.str.html#method.split_whitespace).
381    ///
382    /// Defaults to `true`.
383    pub fn stop_whitespace(self, stop_whitespace: bool) -> Self {
384        SearchOptions {
385            stop_whitespace,
386            ..self
387        }
388    }
389
390    /// Sets the custom token stop word.
391    ///
392    /// This option enables tokenizer to split contents
393    /// and search words by the extra list of custom stop words.
394    ///
395    /// Defaults to `&[]`.
396    ///
397    /// # Examples
398    /// ```
399    /// use simsearch::{SearchOptions, SimSearch};
400    ///
401    /// let mut engine: SimSearch<usize> = SimSearch::new_with(
402    ///     SearchOptions::new().stop_words(vec!["/".to_string(), "\\".to_string()]));
403    ///
404    /// engine.insert(1, "the old/man/and/the sea");
405    ///
406    /// let results = engine.search("old");
407    ///
408    /// assert_eq!(results, &[1]);
409    /// ```
410    pub fn stop_words(self, stop_words: Vec<String>) -> Self {
411        SearchOptions { stop_words, ..self }
412    }
413
414    /// Sets the threshold for search scoring.
415    ///
416    /// Search results will be sorted by their Jaro winkler similarity scores.
417    /// Scores ranges from 0 to 1 where the 1 indicates the most relevant.
418    /// Only the entries with scores greater than the threshold will be returned.
419    ///
420    /// Defaults to `0.8`.
421    pub fn threshold(self, threshold: f64) -> Self {
422        SearchOptions { threshold, ..self }
423    }
424
425    /// Sets whether Levenshtein distance, which is SIMD-accelerated, should be
426    /// used instead of the default Jaro-Winkler distance.
427    ///
428    /// The implementation of Levenshtein distance is very fast but cannot handle Unicode
429    /// strings, unlike the default Jaro-Winkler distance. The strings are treated as byte
430    /// slices with Levenshtein distance, which means that the calculated score may be
431    /// incorrectly lower for Unicode strings, where each character is represented with
432    /// multiple bytes.
433    ///
434    /// Defaults to `false`.
435    pub fn levenshtein(self, levenshtein: bool) -> Self {
436        SearchOptions {
437            levenshtein,
438            ..self
439        }
440    }
441}
442
443impl<Id> Default for SimSearch<Id>
444where
445    Id: Eq + PartialEq + Clone + Hash + Ord,
446{
447    fn default() -> Self {
448        SimSearch::new()
449    }
450}
451
452impl Default for SearchOptions {
453    fn default() -> Self {
454        SearchOptions::new()
455    }
456}