1use std::cmp::min;
35use std::collections::HashMap;
36
/// A single edit separating the intended (correct) word from the typed word.
///
/// Operations describe how the correct word is turned into the typo, as
/// produced by `ErrorModel::min_edit_operations`. The payload is only `char`s,
/// so the type is cheap to copy and can be used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EditOp {
    /// A character of the correct word is missing from the typo.
    Delete(char),
    /// The typo contains an extra character not present in the correct word.
    Insert(char),
    /// `Substitute(expected, typed)`: a character of the correct word was typed
    /// as a different character.
    Substitute(char, char),
    /// `Transpose(first, second)`: the adjacent pair `first`,`second` of the
    /// correct word appears swapped in the typo.
    Transpose(char, char),
}
49
/// Probabilistic model of typing errors.
///
/// Scores how likely a typed string is given the intended word (see
/// `ErrorModel::error_probability`), based on the minimal list of edit
/// operations separating the two.
#[derive(Debug, Clone)]
pub struct ErrorModel {
    /// Probability of a deletion: a character of the correct word is missing
    /// from the typo. `ErrorModel::new` normalizes the four `p_*` weights so
    /// they sum to 1; the fields are public and may also be set directly.
    pub p_deletion: f64,
    /// Probability of an insertion: the typo contains an extra character.
    pub p_insertion: f64,
    /// Probability of typing one character in place of another.
    pub p_substitution: f64,
    /// Probability of swapping two adjacent characters.
    pub p_transposition: f64,
    // NOTE(review): never read by the code in this file — presumably reserved
    // for per-character-pair confusion weights; confirm before relying on it.
    // The leading underscore keeps the unused-field lint quiet.
    _char_confusion: HashMap<(char, char), f64>,
    /// Maximum number of edits considered when relating a typo to a word
    /// (2 unless changed with `with_max_distance`).
    max_edit_distance: usize,
}
66
67impl Default for ErrorModel {
68 fn default() -> Self {
69 Self {
70 p_deletion: 0.25,
71 p_insertion: 0.25,
72 p_substitution: 0.25,
73 p_transposition: 0.25,
74 _char_confusion: HashMap::new(),
75 max_edit_distance: 2, }
77 }
78}
79
80impl ErrorModel {
81 pub fn new(
83 p_deletion: f64,
84 p_insertion: f64,
85 p_substitution: f64,
86 p_transposition: f64,
87 ) -> Self {
88 let total = p_deletion + p_insertion + p_substitution + p_transposition;
90 Self {
91 p_deletion: p_deletion / total,
92 p_insertion: p_insertion / total,
93 p_substitution: p_substitution / total,
94 p_transposition: p_transposition / total,
95 _char_confusion: HashMap::new(),
96 max_edit_distance: 2,
97 }
98 }
99
100 pub fn with_max_distance(mut self, maxdistance: usize) -> Self {
102 self.max_edit_distance = maxdistance;
103 self
104 }
105
106 pub fn error_probability(&self, typo: &str, correct: &str) -> f64 {
108 if typo == correct {
110 return 1.0;
111 }
112
113 let edit_distance = self.min_edit_operations(typo, correct);
115
116 match edit_distance.len() {
117 0 => 1.0, 1 => {
119 match edit_distance[0] {
121 EditOp::Delete(_) => self.p_deletion,
122 EditOp::Insert(_) => self.p_insertion,
123 EditOp::Substitute(_, _) => self.p_substitution,
124 EditOp::Transpose(_, _) => self.p_transposition,
125 }
126 }
127 n => {
128 let base_prob = 0.1f64.powi(n as i32 - 1);
130 let mut prob = base_prob;
131
132 for op in &edit_distance {
133 match op {
134 EditOp::Delete(_) => prob *= self.p_deletion,
135 EditOp::Insert(_) => prob *= self.p_insertion,
136 EditOp::Substitute(_, _) => prob *= self.p_substitution,
137 EditOp::Transpose(_, _) => prob *= self.p_transposition,
138 }
139 }
140
141 prob
142 }
143 }
144 }
145
146 pub fn min_edit_operations(&self, typo: &str, correct: &str) -> Vec<EditOp> {
148 let typo_chars: Vec<char> = typo.chars().collect();
149 let correct_chars: Vec<char> = correct.chars().collect();
150
151 if typo == correct {
153 return vec![];
154 }
155
156 if (typo_chars.len() as isize - correct_chars.len() as isize).abs()
158 > self.max_edit_distance as isize
159 {
160 return vec![EditOp::Substitute('?', '?')];
162 }
163
164 if correct_chars.len() == typo_chars.len() + 1 {
166 for i in 0..correct_chars.len() {
168 let mut test_chars = correct_chars.clone();
169 test_chars.remove(i);
170 if test_chars == typo_chars {
171 return vec![EditOp::Delete(correct_chars[i])];
172 }
173 }
174 } else if correct_chars.len() + 1 == typo_chars.len() {
175 for i in 0..typo_chars.len() {
177 let mut test_chars = typo_chars.clone();
178 test_chars.remove(i);
179 if test_chars == correct_chars {
180 return vec![EditOp::Insert(typo_chars[i])];
181 }
182 }
183 } else if correct_chars.len() == typo_chars.len() {
184 let mut diff_positions = Vec::new();
186
187 for i in 0..correct_chars.len() {
188 if correct_chars[i] != typo_chars[i] {
189 diff_positions.push(i);
190 }
191 }
192
193 if diff_positions.len() == 1 {
194 let i = diff_positions[0];
196 return vec![EditOp::Substitute(correct_chars[i], typo_chars[i])];
197 } else if diff_positions.len() == 2 && diff_positions[0] + 1 == diff_positions[1] {
198 let i = diff_positions[0];
199
200 if correct_chars[i] == typo_chars[i + 1] && correct_chars[i + 1] == typo_chars[i] {
202 return vec![EditOp::Transpose(correct_chars[i], correct_chars[i + 1])];
203 }
204 }
205 }
206
207 let mut operations = Vec::new();
209 let _distance = self.levenshtein_with_ops_efficient(correct, typo, &mut operations);
210 operations
211 }
212
213 fn levenshtein_with_ops_efficient(
216 &self,
217 s1: &str,
218 s2: &str,
219 operations: &mut Vec<EditOp>,
220 ) -> usize {
221 let chars1: Vec<char> = s1.chars().collect();
222 let chars2: Vec<char> = s2.chars().collect();
223 let len1 = chars1.len();
224 let len2 = chars2.len();
225
226 if s1 == s2 {
228 return 0;
229 }
230
231 if (len1 as isize - len2 as isize).abs() > self.max_edit_distance as isize {
233 return self.max_edit_distance + 1; }
235
236 let mut prev_row = (0..=len2).collect::<Vec<_>>();
238 let mut curr_row = vec![0; len2 + 1];
239
240 let mut op_matrix = vec![vec![0; len2 + 1]; len1 + 1];
243
244 for j in 1..=len2 {
246 op_matrix[0][j] = 1; }
248
249 for i in 1..=len1 {
250 curr_row[0] = i;
251 op_matrix[i][0] = 2; for j in 1..=len2 {
254 let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
255
256 let del_cost = prev_row[j] + 1;
258 let ins_cost = curr_row[j - 1] + 1;
259 let sub_cost = prev_row[j - 1] + cost;
260
261 curr_row[j] = min(min(del_cost, ins_cost), sub_cost);
263
264 if curr_row[j] == del_cost {
266 op_matrix[i][j] = 2; } else if curr_row[j] == ins_cost {
268 op_matrix[i][j] = 1; } else if cost > 0 {
270 op_matrix[i][j] = 3; } else {
272 op_matrix[i][j] = 0; }
274
275 if i > 1
277 && j > 1
278 && chars1[i - 1] == chars2[j - 2]
279 && chars1[i - 2] == chars2[j - 1]
280 {
281 let trans_cost = prev_row[j - 2] + 1;
282 if trans_cost < curr_row[j] {
283 curr_row[j] = trans_cost;
284 op_matrix[i][j] = 4; }
286 }
287 }
288
289 if curr_row.iter().all(|&c| c > self.max_edit_distance) {
291 return self.max_edit_distance + 1;
292 }
293
294 std::mem::swap(&mut prev_row, &mut curr_row);
296 }
297
298 let mut i = len1;
300 let mut j = len2;
301 let mut backtrack_ops = Vec::new();
302
303 while i > 0 || j > 0 {
304 match if i == 0 || j == 0 {
305 if i == 0 {
306 1
307 } else {
308 2
309 } } else {
311 op_matrix[i][j]
312 } {
313 0 => {
314 i -= 1;
316 j -= 1;
317 }
318 1 => {
319 j -= 1;
321 backtrack_ops.push(EditOp::Insert(chars2[j]));
322 }
323 2 => {
324 i -= 1;
326 backtrack_ops.push(EditOp::Delete(chars1[i]));
327 }
328 3 => {
329 i -= 1;
331 j -= 1;
332 backtrack_ops.push(EditOp::Substitute(chars1[i], chars2[j]));
333 }
334 4 => {
335 i -= 2;
337 j -= 2;
338 backtrack_ops.push(EditOp::Transpose(chars1[i + 1], chars1[i + 2]));
339 }
340 _ => break, }
342 }
343
344 backtrack_ops.reverse();
346 operations.extend(backtrack_ops);
347
348 prev_row[len2]
350 }
351
352 pub fn levenshtein_with_ops(&self, s1: &str, s2: &str, operations: &mut Vec<EditOp>) -> usize {
354 let chars1: Vec<char> = s1.chars().collect();
355 let chars2: Vec<char> = s2.chars().collect();
356 let len1 = chars1.len();
357 let len2 = chars2.len();
358
359 let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
361
362 for (i, row) in matrix.iter_mut().enumerate().take(len1 + 1) {
364 row[0] = i;
365 }
366
367 for j in 0..=len2 {
368 matrix[0][j] = j;
369 }
370
371 for i in 1..=len1 {
373 for j in 1..=len2 {
374 let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
375
376 matrix[i][j] = min(
377 min(
378 matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, ),
381 matrix[i - 1][j - 1] + cost, );
383
384 if i > 1
386 && j > 1
387 && chars1[i - 1] == chars2[j - 2]
388 && chars1[i - 2] == chars2[j - 1]
389 {
390 matrix[i][j] = min(
391 matrix[i][j],
392 matrix[i - 2][j - 2] + 1, );
394 }
395 }
396 }
397
398 let mut i = len1;
400 let mut j = len2;
401
402 let mut temp_ops = Vec::new();
404
405 while i > 0 || j > 0 {
406 if i > 0 && j > 0 && chars1[i - 1] == chars2[j - 1] {
407 i -= 1;
409 j -= 1;
410 } else if i > 1
411 && j > 1
412 && chars1[i - 1] == chars2[j - 2]
413 && chars1[i - 2] == chars2[j - 1]
414 && matrix[i][j] == matrix[i - 2][j - 2] + 1
415 {
416 temp_ops.push(EditOp::Transpose(chars1[i - 2], chars1[i - 1]));
418 i -= 2;
419 j -= 2;
420 } else if i > 0 && j > 0 && matrix[i][j] == matrix[i - 1][j - 1] + 1 {
421 temp_ops.push(EditOp::Substitute(chars1[i - 1], chars2[j - 1]));
423 i -= 1;
424 j -= 1;
425 } else if i > 0 && matrix[i][j] == matrix[i - 1][j] + 1 {
426 temp_ops.push(EditOp::Delete(chars1[i - 1]));
428 i -= 1;
429 } else if j > 0 && matrix[i][j] == matrix[i][j - 1] + 1 {
430 temp_ops.push(EditOp::Insert(chars2[j - 1]));
432 j -= 1;
433 } else {
434 break;
436 }
437 }
438
439 temp_ops.reverse();
441 operations.extend(temp_ops);
442
443 matrix[len1][len2]
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_model() {
        let model = ErrorModel::default();

        // One (typo, correct) pair per kind of edit: deletion, insertion,
        // substitution, transposition. Each must be judged possible.
        let single_edit_pairs = [
            ("cat", "cart"),
            ("cart", "cat"),
            ("cat", "cut"),
            ("form", "from"),
        ];
        let probabilities: Vec<f64> = single_edit_pairs
            .iter()
            .map(|&(typo, correct)| model.error_probability(typo, correct))
            .collect();
        for p in probabilities {
            assert!(p > 0.0);
        }

        assert_eq!(model.error_probability("word", "word"), 1.0);
    }

    #[test]
    fn test_edit_operations() {
        let model = ErrorModel::default();
        let ops_for = |typo: &str, correct: &str| model.min_edit_operations(typo, correct);

        assert!(matches!(ops_for("cat", "cart").as_slice(), [EditOp::Delete('r')]));
        assert!(matches!(ops_for("cart", "cat").as_slice(), [EditOp::Insert('r')]));
        assert!(matches!(ops_for("cut", "cat").as_slice(), [EditOp::Substitute('a', 'u')]));
        assert!(matches!(ops_for("from", "form").as_slice(), [EditOp::Transpose('o', 'r')]));
    }

    #[test]
    fn test_efficient_levenshtein() {
        let model = ErrorModel::default();

        // Runs the full-matrix and the efficient implementation on one pair.
        let run_both = |s1: &str, s2: &str| {
            let mut full_ops = Vec::new();
            let mut fast_ops = Vec::new();
            let full = model.levenshtein_with_ops(s1, s2, &mut full_ops);
            let fast = model.levenshtein_with_ops_efficient(s1, s2, &mut fast_ops);
            ((full, full_ops), (fast, fast_ops))
        };

        let ((full, full_ops), (fast, fast_ops)) = run_both("hello", "hello");
        assert_eq!(full, 0);
        assert_eq!(fast, 0);
        assert!(full_ops.is_empty());
        assert!(fast_ops.is_empty());

        for &(s1, s2) in [("cat", "bat"), ("cat", "cats"), ("cats", "cat")].iter() {
            let ((full, _), (fast, _)) = run_both(s1, s2);
            assert_eq!(full, 1);
            assert_eq!(fast, 1);
        }

        let ((_, full_ops), (_, fast_ops)) = run_both("abc", "acb");
        assert!(full_ops.len() <= 2);
        assert!(fast_ops.len() <= 2);

        let ((full, _), (fast, _)) = run_both("programming", "programmer");
        assert!(full <= 3);
        assert!(fast <= 3);
    }

    #[test]
    fn test_early_termination() {
        let tight = ErrorModel::default().with_max_distance(1);
        let ops = tight.min_edit_operations("cat", "dog");
        // An out-of-range pair may yield no operations; if any are reported,
        // they must start with a substitution or span several edits.
        if let Some(first) = ops.first() {
            assert!(matches!(first, EditOp::Substitute(_, _)) || ops.len() > 1);
        }

        let loose = ErrorModel::default().with_max_distance(3);
        assert!(!loose.min_edit_operations("kitten", "sitting").is_empty());

        let ops = loose.min_edit_operations("algorithm", "logarithm");
        match ops.as_slice() {
            [single] => assert!(matches!(single, EditOp::Substitute(_, _))),
            _ => assert!(!ops.is_empty()),
        }
    }
}