yass/levenshtein.rs

use super::{ExpectTokenizerType, StrSim, TokenizerType};
use crate::error::StrSimError;
use anyhow::Result;
use derive_more::Display;
use hashbrown::HashMap;

/// Character-level Levenshtein metric with configurable edit costs. Each
/// operation (insertion, deletion, substitution) can be given a per-character
/// cost; characters without an explicit entry fall back to the corresponding
/// `*_default` cost.
#[derive(Display)]
#[display(fmt = "Levenshtein")]
pub struct Levenshtein {
    pub insertion: HashMap<char, f64>,
    pub insertion_default: f64,
    pub deletion: HashMap<char, f64>,
    pub deletion_default: f64,
    pub substitution: HashMap<char, HashMap<char, f64>>,
    pub substitution_default: f64,
    /// Similarity threshold below which `similarity` short-circuits to 0.0;
    /// a non-positive value disables the check.
    pub lowerbound: f64,
}

impl Levenshtein {
    pub fn default() -> Self {
        Levenshtein {
            insertion: HashMap::new(),
            insertion_default: 1.0,
            deletion: HashMap::new(),
            deletion_default: 1.0,
            substitution: HashMap::new(),
            substitution_default: 1.0,
            // A negative lowerbound disables the early-exit check in `similarity`.
            lowerbound: -1.0,
        }
    }

    /// Maximum possible edit cost for `chars`: the sum, over all characters,
    /// of the largest of their insertion, deletion, and substitution costs.
    /// Used by `similarity` as the normalising denominator.
    pub fn compute_max_cost(&self, chars: &[char]) -> f64 {
        chars
            .iter()
            .map(|c| {
                self.insertion
                    .get(c)
                    .unwrap_or(&self.insertion_default)
                    .max(
                        self.deletion.get(c).unwrap_or(&self.deletion_default).max(
                            *self
                                .substitution
                                .get(c)
                                // RLTK has a bug here; this fix has not been verified.
                                .map(|subs| {
                                    subs.values()
                                        .max_by(|&a, &b| a.partial_cmp(b).unwrap())
                                        .unwrap()
                                })
                                .unwrap_or(&self.substitution_default),
                        ),
                    )
            })
            .sum()
    }

    /// Estimate of the cheapest single-character edit cost over `chars`;
    /// used by `similarity` as a lower bound for early pruning.
    /// Panics if `chars` is empty; callers guard against empty inputs.
    pub fn estimate_min_char_cost(&self, chars: &[char]) -> f64 {
        chars
            .iter()
            .map(|c| {
                self.insertion
                    .get(c)
                    .unwrap_or(&self.insertion_default)
                    .min(
                        self.deletion.get(c).unwrap_or(&self.deletion_default).min(
                            *self
                                .substitution
                                .get(c)
                                // RLTK has a bug here; this fix has not been verified.
                                .map(|subs| {
                                    subs.values()
                                        .max_by(|&a, &b| a.partial_cmp(b).unwrap())
                                        .unwrap()
                                })
                                .unwrap_or(&self.substitution_default),
                        ),
                    )
            })
            .min_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }
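
    // Illustration (assuming the default unit costs above): for "kitten",
    // compute_max_cost sums max(1.0, 1.0, 1.0) per character, giving 6.0, while
    // estimate_min_char_cost returns the cheapest single-character edit, 1.0.
    // `similarity` divides by the former and multiplies the latter by the length
    // difference to prune candidate pairs cheaply.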

    /// The Levenshtein distance between two words is the minimum number of
    /// single-character edits (insertions, deletions, or substitutions) required
    /// to change one word into the other; here each edit is weighted by the
    /// configured per-character cost. A worked example follows the function body.
    pub fn distance(&self, s1: &[char], s2: &[char]) -> f64 {
        let n1 = s1.len();
        let n2 = s2.len();
        if n1 == 0 && n2 == 0 {
            return 0.0;
        }

        // dp[i][j] holds the minimum cost of transforming s1[..i] into s2[..j].
        let mut dp: Vec<Vec<f64>> = vec![vec![0.0; n2 + 1]; n1 + 1];
        for i in 0..=n1 {
            for j in 0..=n2 {
                if i == 0 && j == 0 {
                    continue;
                }

                if i == 0 {
                    // Top row: s1 is empty, so build s2[..j] by insertions.
                    let c = &s2[j - 1];
                    dp[i][j] = *self.insertion.get(c).unwrap_or(&self.insertion_default);
                    dp[i][j] += dp[i][j - 1];
                } else if j == 0 {
                    // Leftmost column: s2 is empty, so remove s1[..i] by deletions.
                    let c = &s1[i - 1];
                    dp[i][j] = *self.deletion.get(c).unwrap_or(&self.deletion_default);
                    dp[i][j] += dp[i - 1][j];
                } else {
                    let c1 = &s1[i - 1];
                    let c2 = &s2[j - 1];
                    let insert_cost = self.insertion.get(c2).unwrap_or(&self.insertion_default);
                    let delete_cost = self.deletion.get(c1).unwrap_or(&self.deletion_default);
                    let substitute_cost = self
                        .substitution
                        .get(c1)
                        .map(|subs| subs.get(c2).unwrap_or(&self.substitution_default))
                        .unwrap_or(&self.substitution_default);

                    if c1 == c2 {
                        dp[i][j] = dp[i - 1][j - 1];
                    } else {
                        dp[i][j] = (dp[i][j - 1] + insert_cost).min(
                            (dp[i - 1][j] + delete_cost).min(dp[i - 1][j - 1] + substitute_cost),
                        );
                    }
                }
            }
        }
        dp[n1][n2]
    }
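
    // Worked example (assuming default unit costs): for s1 = "kitten" and
    // s2 = "sitting", the cheapest edit script is substitute 'k'->'s',
    // substitute 'e'->'i', and insert 'g', so `distance` returns 3.0.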

    /// Compute the Levenshtein similarity between two strings as
    /// `1 - (levenshtein_distance / max_cost(key, query))`.
    ///
    /// When `lowerbound` is positive, a cheap length-based estimate is used to
    /// return 0.0 early for pairs that cannot reach the threshold.
    ///
    /// Directly translated from RLTK's implementation.
    pub fn similarity(&self, s1: &[char], s2: &[char]) -> Result<f64, StrSimError> {
        let max_cost = self.compute_max_cost(s1).max(self.compute_max_cost(s2));

        if self.lowerbound > 0.0 {
            let diff = s1.len().abs_diff(s2.len()) as f64;
            if s1.is_empty() && s2.is_empty() {
                return Ok(1.0);
            }
            let min_lev = if s1.is_empty() {
                diff * self.estimate_min_char_cost(s2)
            } else if s2.is_empty() {
                diff * self.estimate_min_char_cost(s1)
            } else {
                diff * self
                    .estimate_min_char_cost(s1)
                    .min(self.estimate_min_char_cost(s2))
            };
            let est_sim = 1.0 - (min_lev / max_cost);
            if est_sim < self.lowerbound {
                return Ok(0.0);
            }
        }

        let lev = self.distance(s1, s2);
        if max_cost < lev {
            return Err(StrSimError::InvalidConfigData(
                "Illegal value of operation costs".to_owned(),
            ));
        }

        if max_cost == 0.0 {
            return Ok(1.0);
        }

        let lev_sim = 1.0 - (lev / max_cost);
        if self.lowerbound > 0.0 && lev_sim < self.lowerbound {
            return Ok(0.0);
        }
        Ok(lev_sim)
    }
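
    // Numeric sketch (default unit costs): for "kitten" vs "sitting",
    // compute_max_cost gives max(6.0, 7.0) = 7.0 and distance gives 3.0,
    // so similarity returns 1.0 - 3.0 / 7.0 ≈ 0.571.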
}

impl ExpectTokenizerType for Levenshtein {
    fn get_expected_tokenizer_type(&self) -> TokenizerType {
        TokenizerType::Seq(Box::new(None))
    }
}

impl StrSim<Vec<char>> for Levenshtein {
    fn similarity_pre_tok2(&self, s1: &Vec<char>, s2: &Vec<char>) -> Result<f64, StrSimError> {
        self.similarity(s1, s2)
    }
}
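
// Minimal sanity checks sketching the expected behaviour; they exercise only the
// public API defined above, using the default unit costs plus one custom
// substitution cost. The `chars` helper is local to this sketch.
#[cfg(test)]
mod tests {
    use super::Levenshtein;
    use hashbrown::HashMap;

    fn chars(s: &str) -> Vec<char> {
        s.chars().collect()
    }

    #[test]
    fn distance_with_default_costs() {
        let lev = Levenshtein::default();
        // Classic example: substitute 'k'->'s', substitute 'e'->'i', insert 'g'.
        assert!((lev.distance(&chars("kitten"), &chars("sitting")) - 3.0).abs() < 1e-9);
        // Identical strings need no edits.
        assert!(lev.distance(&chars("abc"), &chars("abc")).abs() < 1e-9);
    }

    #[test]
    fn similarity_with_default_costs() {
        let lev = Levenshtein::default();
        // An error would surface as NaN and fail the assertion below.
        let sim = lev
            .similarity(&chars("kitten"), &chars("sitting"))
            .unwrap_or(f64::NAN);
        // max_cost = 7.0 and distance = 3.0, so similarity = 1 - 3/7.
        assert!((sim - (1.0 - 3.0 / 7.0)).abs() < 1e-9);
        // Two empty strings are treated as identical.
        let empty_sim = lev.similarity(&chars(""), &chars("")).unwrap_or(f64::NAN);
        assert!((empty_sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn custom_substitution_cost_is_used() {
        let mut lev = Levenshtein::default();
        let mut sub_a = HashMap::new();
        sub_a.insert('b', 0.5);
        lev.substitution.insert('a', sub_a);
        // Substituting 'a' -> 'b' is now cheaper than a delete plus an insert.
        assert!((lev.distance(&chars("a"), &chars("b")) - 0.5).abs() < 1e-9);
    }
}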