1use super::{ExpectTokenizerType, StrSim, TokenizerType};
2use crate::error::StrSimError;
3use anyhow::Result;
4use derive_more::Display;
5use hashbrown::HashMap;
6
7#[derive(Display)]
8#[display(fmt = "Levenshtein")]
9pub struct Levenshtein {
10 pub insertion: HashMap<char, f64>,
11 pub insertion_default: f64,
12 pub deletion: HashMap<char, f64>,
13 pub deletion_default: f64,
14 pub substitution: HashMap<char, HashMap<char, f64>>,
15 pub substitution_default: f64,
16 pub lowerbound: f64,
17}
18
19impl Levenshtein {
20 pub fn default() -> Self {
21 Levenshtein {
22 insertion: HashMap::new(),
23 insertion_default: 1.0,
24 deletion: HashMap::new(),
25 deletion_default: 1.0,
26 substitution: HashMap::new(),
27 substitution_default: 1.0,
28 lowerbound: -1.0,
29 }
30 }
31
32 pub fn compute_max_cost(&self, chars: &[char]) -> f64 {
33 chars
34 .iter()
35 .map(|c| {
36 self.insertion
37 .get(c)
38 .unwrap_or(&self.insertion_default)
39 .max(
40 self.deletion.get(c).unwrap_or(&self.deletion_default).max(
41 *self
42 .substitution
43 .get(c)
44 .map(|subs| {
46 subs.values()
47 .max_by(|&a, &b| a.partial_cmp(b).unwrap())
48 .unwrap()
49 })
50 .unwrap_or(&self.substitution_default),
51 ),
52 )
53 })
54 .sum()
55 }
56
57 pub fn estimate_min_char_cost(&self, chars: &[char]) -> f64 {
58 chars
59 .iter()
60 .map(|c| {
61 self.insertion
62 .get(c)
63 .unwrap_or(&self.insertion_default)
64 .min(
65 self.deletion.get(c).unwrap_or(&self.deletion_default).min(
66 *self
67 .substitution
68 .get(c)
69 .map(|subs| {
71 subs.values()
72 .max_by(|&a, &b| a.partial_cmp(b).unwrap())
73 .unwrap()
74 })
75 .unwrap_or(&self.substitution_default),
76 ),
77 )
78 })
79 .min_by(|a, b| a.partial_cmp(b).unwrap())
80 .unwrap()
81 }
82
83 pub fn distance(&self, s1: &[char], s2: &[char]) -> f64 {
86 let n1 = s1.len();
87 let n2 = s2.len();
88 if n1 == 0 && n2 == 0 {
89 return 0.0;
90 }
91
92 let mut dp: Vec<Vec<f64>> = vec![vec![0.0; n2 + 1]; n1 + 1];
93 for i in 0..=n1 {
94 for j in 0..=n2 {
95 if i == 0 && j == 0 {
96 continue;
97 }
98
99 if i == 0 {
100 let c = &s2[j - 1];
102 dp[i][j] = *self.insertion.get(c).unwrap_or(&self.insertion_default);
103 dp[i][j] += dp[i][j - 1];
104 } else if j == 0 {
105 let c = &s1[i - 1];
107 dp[i][j] = *self.deletion.get(c).unwrap_or(&self.deletion_default);
108 dp[i][j] += dp[i - 1][j];
109 } else {
110 let c1 = &s1[i - 1];
111 let c2 = &s2[j - 1];
112 let insert_cost = self.insertion.get(c2).unwrap_or(&self.insertion_default);
113 let delete_cost = self.deletion.get(c1).unwrap_or(&self.deletion_default);
114 let substitute_cost = self
115 .substitution
116 .get(c1)
117 .map(|subs| subs.get(c2).unwrap_or(&self.substitution_default))
118 .unwrap_or(&self.substitution_default);
119
120 if c1 == c2 {
121 dp[i][j] = dp[i - 1][j - 1];
122 } else {
123 dp[i][j] = (dp[i][j - 1] + insert_cost).min(
124 (dp[i - 1][j] + delete_cost).min(dp[i - 1][j - 1] + substitute_cost),
125 );
126 }
127 }
128 }
129 }
130 return dp[n1][n2];
131 }
132
133 pub fn similarity(&self, s1: &[char], s2: &[char]) -> Result<f64, StrSimError> {
140 let max_cost = self.compute_max_cost(&s1).max(self.compute_max_cost(&s2)) as f64;
141 let min_lev: f64;
142
143 if self.lowerbound > 0.0 {
144 let diff = s1.len().abs_diff(s2.len()) as f64;
145 if s1.len() == 0 && s2.len() == 0 {
146 return Ok(1.0);
147 }
148 if s1.len() == 0 {
149 min_lev = diff * self.estimate_min_char_cost(&s2) as f64;
150 } else if s2.len() == 0 {
151 min_lev = diff * self.estimate_min_char_cost(&s1) as f64;
152 } else {
153 min_lev = diff
154 * self
155 .estimate_min_char_cost(&s1)
156 .min(self.estimate_min_char_cost(&s2)) as f64;
157 }
158 let est_sim = 1.0 - (min_lev / max_cost);
159 if est_sim < self.lowerbound {
160 return Ok(0.0);
161 }
162 }
163
164 let lev = self.distance(&s1, &s2) as f64;
165 if max_cost < lev {
166 return Err(StrSimError::InvalidConfigData(
167 "Illegal value of operation costs".to_owned(),
168 ));
169 }
170
171 if max_cost == 0.0 {
172 return Ok(1.0);
173 }
174
175 let lev_sim = 1.0 - (lev / max_cost);
176 if self.lowerbound > 0.0 && lev_sim < self.lowerbound {
177 return Ok(0.0);
178 }
179 Ok(lev_sim)
180 }
181}
182
183impl ExpectTokenizerType for Levenshtein {
184 fn get_expected_tokenizer_type(&self) -> TokenizerType {
185 TokenizerType::Seq(Box::new(None))
186 }
187}
188
189impl StrSim<Vec<char>> for Levenshtein {
190 fn similarity_pre_tok2(&self, s1: &Vec<char>, s2: &Vec<char>) -> Result<f64, StrSimError> {
191 self.similarity(s1, s2)
192 }
193}